1. 程式人生 > >防止過擬合的基本方法

防止過擬合的基本方法

過擬合是訓練神經網路中常見的問題,本文討論了產生過擬合的原因,如何發現過擬合,以及簡單的解決方法。

發現過擬合問題

在訓練神經網路時,我們常常有訓練集、測試集和驗證集三種資料集。

有時候我們會發現,訓練出來的神經網路在訓練集上表現很好(準確率很高),但在測試集上的準確率比較差。這種現象一般被認為是過擬合,也就是過度學習了訓練集上的特徵,導致泛化能力較差。

hold out 方法

那麼如何發現是否存在過擬合方法呢?一種簡單的思路就是把訓練集分為訓練集和驗證集,其中訓練集用來訓練資料,驗證集用來檢測準確率。

我們在每個迭代期的最後都計算在驗證集上的分類準確率,一旦分類準確率已經飽和,就停止訓練。這個策略被稱為提前停止

示例

以MNIST資料集為例,這裡使用1000個樣本作為訓練集,迭代週期為400,使用交叉熵代價函式,隨機梯度下降,我們可以畫出其損失值與準確率。

訓練集上的損失值和準確率:

驗證集上的損失值和準確率:

對比測試集與驗證集的準確率:

可以發現:訓練集上的損失值越來越小,正確率已經達到了100%,而驗證集上的損失會突然增大,正確率沒有提升。這就產生了過擬合問題。

增大訓練量

一個最直觀,也是最有效的方式就是增大訓練量。有了足夠的訓練資料,就算是一個規模很大的網路也不太容易過擬合。

例如,如果我們將MNIST的訓練資料增大到50000(擴大了50倍),則可以發現訓練集和測試集的正確率差距不大,且一直在增加(這裡只迭代了30次):

但很不幸,一般來說,訓練資料時有限的,這種方法不太實際。

人為擴充套件訓練資料

當我們缺乏訓練資料時,可以使用一種巧妙的方式人為構造資料。

例如,對於MNIST手寫數字資料集,我們可以將每幅影象左右旋轉15°。這應該還是被識別成同樣的數字,但對於我們的神經網路來說(畫素級),這就是完全不同的輸入。

因此,將這些樣本加入到訓練資料中很可能幫助我們的網路學習更多如何分類數字。

這個想法很強大並且已經被廣泛應用了,更多討論可以檢視這篇論文

再舉個例子,當我們訓練神經網路進行語音識別時,我們可以對這些語音隨機加上一些噪音–加速或減速。

規範化(regularization)

除了增大訓練樣本,另一種能減輕過擬合的方法是降低網路的規模。但往往大規模的神經網路有更強的潛力,因此我們想使用另外的技術。

規範化是神經網路中常用的方法,雖然沒有足夠的理論,但規範化的神經網路往往能夠比非規範化的泛化能力更強。

一般來說,我們只需要對w進行規範化,而幾乎不對b進行規範化。

L2規範化

學習規則

最常用的規範化手段,也稱為權重衰減(weight decay)。

L2規範化的想法是增加一個額外的項到代價函式上,這個項被稱為規範化項。例如,對於規範化的交叉熵:

C=1nx[yjlnajL+(1yj)ln(1ajL)]+λ2nww2

對於其他形式的代價函式,都可以寫成:

C=C0+λ2nww2

由於我們的目的是使得代價函式越小越好,因此直覺的看,規範化的效果是讓網路傾向於學習小一點的權重。

換言之,規範化可以當做一種尋找小的權重和最小化原始代價函式之間的折中。

現在,我們再對wb求偏導:

Cw=C0w+λnw Cw=C0b

因此,我們計算規範化的代價函式的梯度是很簡單的:僅僅需要反向傳播,然後加上λnw得到所有權重的偏導數。而偏置的偏導數不需要變化。所以權重的學習規則為:

w(1λn)wηmxCxw bbηmxCxb

這裡也表明,我們傾向於使得權重更小一點。

那這樣,是否會讓權重不斷下降變為0呢?但實際上不是這樣的,因為如果在原始代價函式中的下降會造成其他項使得權重增加。

示例

我們依然來看MNIST的例子。這裡,我使用λ=0.1的規範化項進行學習。

訓練集上的準確率和損失值和之前一樣:

測試集上的損失值不斷減少,準確率不斷提高,符合預期:

L1規範化

學習規則

這個方法是在未規範化的代價函式上加一個權重絕對值的和:

C=C0+λnw|w|

對其進行求偏導得:

Cw=C0w+λnsgn(w)

其中sgn()就是w的正負號。

與L2規範化的聯絡

我們將L1規範化與L2規範化進行對比: