防止過擬合的方法 預測鸞鳳花(sklearn)
1. 防止過擬合的方法有哪些?
過擬合(overfitting)是指在模型參數擬合過程中的問題,由於訓練數據包含抽樣誤差,訓練時,復雜的模型將抽樣誤差也考慮在內,將抽樣誤差也進行了很好的擬合。
產生過擬合問題的原因大體有兩個:訓練樣本太少或者模型太復雜。
防止過擬合問題的方法:
(1)增加訓練數據。
考慮增加訓練樣本的數量
使用數據集估計數據分布參數,使用估計分布參數生成訓練樣本
使用數據增強
(2)減小模型的復雜度。
a.減少網絡的層數或者神經元數量。這個很好理解,介紹網絡的層數或者神經元的數量會使模型的擬合能力降低。
b.參數範數懲罰。參數範數懲罰通常采用L1和L2參數正則化(關於L1和L2的區別聯系請戳這裏)。
c.提前終止(Early stopping);
d.添加噪聲。添加噪聲可以在輸入、權值,網絡相應中添加。
e.結合多種模型。這種方法中使用不同的模型擬合不同的數據集,例如使用 Bagging,Boosting,Dropout、貝葉斯方法
而在深度學習中,通常解決的方法如下
Early stopping方法的具體做法是,在每一個Epoch結束時(一個Epoch集為對所有的訓練數據的一輪遍歷)計算validation data的accuracy,當accuracy不再提高時,就停止訓練。
獲取更多數據(從數據源頭獲取更多數據 根據當前數據集估計數據分布參數,使用該分布產生更多數據 數據增強(Data Augmentation))
正則化(直接將權值的大小加入到 Cost 裏,在訓練的時候限制權值變大)
dropout:在訓練時,每次隨機(如50%概率)忽略隱層的某些節點;
2. 使用邏輯回歸(Logistic Regression)對鳶尾花數據(多分類問題)進行預測,可以直接使用sklearn中的LR方法,並嘗試使用不同的參數,包括正則化的方法,正則項系數,求解優化器,以及將二分類模型轉化為多分類模型的方法。
獲取鳶尾花數據的方法:
from sklearn.datasets import load_iris
X, y = load_iris(return_X_y=True)
print(__doc__) # Code source: Ga?l Varoquaux # Modified for documentation by Jaques Grobler # License: BSD 3 clause import numpy as np import matplotlib.pyplot as plt from sklearn import linear_model, datasets # import some data to play with iris = datasets.load_iris() X = iris.data[:, :2] #we only take the first two features. Y = iris.target h = .02 # step size in the mesh logreg = linear_model.LogisticRegression(C=1e5) # we create an instance of Neighbours Classifier and fit the data. logreg.fit(X, Y) # Plot the decision boundary. For that, we will assign a color to each # point in the mesh [x_min, x_max]x[y_min, y_max]. x_min, x_max = X[:, 0].min() - .5, X[:, 0].max() + .5 y_min, y_max = X[:, 1].min() - .5, X[:, 1].max() + .5 xx, yy = np.meshgrid(np.arange(x_min, x_max, h), np.arange(y_min, y_max, h)) Z = logreg.predict(np.c_[xx.ravel(), yy.ravel()]) # Put the result into a color plot Z = Z.reshape(xx.shape) plt.figure(1, figsize=(4, 3)) plt.pcolormesh(xx, yy, Z, cmap=plt.cm.Paired) # Plot also the training points plt.scatter(X[:, 0], X[:, 1], c=Y, edgecolors=‘k‘, cmap=plt.cm.Paired) plt.xlabel(‘Sepal length‘) plt.ylabel(‘Sepal width‘) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.xticks(()) plt.yticks(()) plt.show()
防止過擬合的方法 預測鸞鳳花(sklearn)