迴圈神經網路的訓練(2)
權重梯度的計算
現在,我們終於來到了BPTT演算法的最後一步:計算每個權重的梯度。
首先,我們計算誤差函式E對權重矩陣W的梯度∂E∂W。
上圖展示了我們到目前為止,在前兩步中已經計算得到的量,包括每個時刻t 迴圈層的輸出值st,以及誤差項δt。
回憶一下我們在文章零基礎入門深度學習(3) - 神經網路和反向傳播演算法介紹的全連線網路的權重梯度計算演算法:只要知道了任意一個時刻的誤差項δt,以及上一個時刻迴圈層的輸出值st−1,就可以按照下面的公式求出權重矩陣在t時刻的梯度∇WtE:
∇WtE=⎡⎣⎢⎢⎢⎢⎢⎢δt1st−11δt2st−11..δtnst−11δt1st−12δ
在式5中,δti表示t時刻誤差項向量的第i個分量;st−1i表示t-1時刻迴圈層第i個神經元的輸出值。
我們下面可以簡單推導一下式5。
我們知道:
nett=⎡⎣⎢⎢⎢⎢⎢nett1nett2..nettn⎤⎦⎥⎥⎥⎥⎥==Uxt+Wst−1Uxt+⎡⎣⎢⎢⎢⎢w11w21..wn1w12w22wn2.........w1nw2nwnn⎤⎦⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢st−11st−12..st−1n⎤⎦⎥⎥⎥⎥⎥Uxt+⎡⎣⎢⎢⎢⎢⎢w11st−11+w12st
因為對W求導與Uxt無關,我們不再考慮。現在,我們考慮對權重項wji求導。通過觀察上式我們可以看到wji只與nettj有關,所以:
∂E∂wji==∂E∂nettj∂nettj∂wjiδtjst−1i(47)(48)
按照上面的規律就可以生成式5裡面的矩陣。
我們已經求得了權重矩陣W在t時刻的梯度∇WtE,最終的梯度∇WE是各個時刻的梯度之和:
∇WE==∑i=1t∇WiE⎡⎣⎢⎢⎢⎢⎢⎢δ
式6就是計算迴圈層權重矩陣W的梯度的公式。
----------數學公式超高能預警----------
前面已經介紹了∇WE的計算方法,看上去還是比較直觀的。然而,讀者也許會困惑,為什麼最終的梯度是各個時刻的梯度之和呢?我們前面只是直接用了這個結論,實際上這裡面是有道理的,只是這個數學推導比較繞腦子。感興趣的同學可以仔細閱讀接下來這一段,它用到了矩陣對矩陣求導、張量與向量相乘運算的一些法則。
我們還是從這個式子開始:
nett=Uxt+Wf(nett−1)
因為Uxt與W完全無關,我們把它看做常量。現在,考慮第一個式子加號右邊的部分,因為W和f(nett−1)都是W的函式,因此我們要用到大學裡面都學過的導數乘法運算:
(uv)′=u′v+uv′
因此,上面第一個式子寫成:
∂nett∂W=∂W∂Wf(nett−1)+W∂f(nett−1)∂W
我們最終需要計算的是∇WE:
∇WE===∂E∂W∂E∂nett∂nett∂WδTt∂W∂Wf(nett−1)+δTtW∂f(nett−1)∂W(式7)(51)(52)(53)
我們先計算式7加號左邊的部分。∂W∂W是矩陣對矩陣求導,其結果是一個四維張量(tensor),如下所示:
∂W∂W===⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w11∂W∂w21∂W..∂wn1∂W∂w12∂W∂w22∂W∂wn2∂W.........∂w1n∂W∂w2n∂W∂wnn∂W⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w11∂w11∂w11∂w21..∂w11∂wn1∂w11∂w12∂w11∂w22∂w11∂wn2.........∂w11∂1n∂w11∂2n∂w11∂nn⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢⎢∂w12∂w11∂w12∂w21..∂w12∂wn1∂w12∂w12∂w12∂w22∂w12∂wn2.........∂w12∂1n∂w12∂2n∂w12∂nn⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢10..0000.........000⎤⎦⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢00..0100.........000⎤⎦⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥(54)(55)(56)
接下來,我們知道st−1=f(nett−1),它是一個列向量。我們讓上面的四維張量與這個向量相乘,得到了一個三維張量,再左乘行向量δTt,最終得到一個矩陣:
δTt∂W∂Wf(nett−1)======δTt∂W∂Wst−1δTt⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢10..0000.........000⎤⎦⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢00..0100.........000⎤⎦⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢st−11st−12..st−1n⎤⎦⎥⎥⎥⎥⎥δTt⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢st−110..0⎤⎦⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢st−120..0⎤⎦⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥[δt1δt2...δtn]⎡⎣⎢⎢⎢⎢⎢⎢⎢⎢⎡⎣⎢⎢⎢⎢⎢st−110..0⎤⎦⎥⎥⎥⎥⎥..⎡⎣⎢⎢⎢⎢⎢st−120..0⎤⎦⎥⎥⎥⎥⎥...⎤⎦⎥⎥⎥⎥⎥⎥⎥⎥⎡⎣⎢⎢⎢⎢⎢⎢δt1st−11δt2st−11..δtnst−11δt1st−12δt2st−12δtnst−12.........δt1st−1nδt2st−1nδtnst−1n⎤⎦⎥⎥⎥⎥⎥⎥∇WtE(57)(58)(59)(60)(61)(62)
接下來,我們計算式7加號右邊的部分:
δTtW∂f(nett−1)∂W====δTtW∂f(nett−1)∂nett−1∂nett−1∂WδTtWf′(nett−1)∂nett−1∂WδTt∂nett∂nett−1∂nett−1∂WδTt−1∂nett−1∂W(63)(64)(65)(66)
於是,我們得到了如下遞推公式:
∇WE======∂E∂W∂E∂nett∂nett∂W∇WtE+δTt−1∂nett−1∂W∇WtE+∇Wt−1E+δTt−2∂nett−2∂W∇WtE+∇Wt−1E+...+∇W1E∑k=1t∇WkE(67)(68)(69)(70)(71)(72)
這樣,我們就證明了:最終的梯度∇WE是各個時刻的梯度之和。
----------數學公式超高能預警解除----------
同權重矩陣W類似,我們可以得到權重矩陣U的計算方法。
∇UtE=⎡⎣⎢⎢⎢⎢⎢⎢δt1xt1δt2xt1..δtnxt1δt1xt2δt2xt2δtnxt2.........δt1xtmδt2xtmδtnxtm⎤⎦⎥⎥⎥⎥⎥⎥(式8)
式8是誤差函式在t時刻對權重矩陣U的梯度。和權重矩陣W一樣,最終的梯度也是各個時刻的梯度之和:
∇UE=∑i=1t∇UiE
具體的證明這裡就不再贅述了,感興趣的讀者可以練習推導一下。
RNN的梯度爆炸和消失問題
不幸的是,實踐中前面介紹的幾種RNNs並不能很好的處理較長的序列。一個主要的原因是,RNN在訓練中很容易發生梯度爆炸和梯度消失,這導致訓練時梯度不能在較長序列中一直傳遞下去,從而使RNN無法捕捉到長距離的影響。
為什麼RNN會產生梯度爆炸和消失問題呢?我們接下來將詳細分析一下原因。我們根據式3可得:
δTk=∥δTk∥⩽⩽δTt∏i=kt−1Wdiag[f′(neti)]∥δTt∥∏i=kt−1∥W∥∥diag[f′(neti)]∥∥δTt∥(βWβf)t−k(73)(74)(75)
上式的β定義為矩陣的模的上界。因為上式是一個指數函式,如果t-k很大的話(也就是向前看很遠的時候),會導致對應的誤差項的值增長或縮小的非常快,這樣就會導致相應的梯度爆炸和梯度消失問題(取決於β大於1還是小於1)。
通常來說,梯度爆炸更容易處理一些。因為梯度爆炸的時候,我們的程式會收到NaN錯誤。我們也可以設定一個梯度閾值,當梯度超過這個閾值的時候可以直接擷取。
梯度消失更難檢測,而且也更難處理一些。總的來說,我們有三種方法應對梯度消失問題:
- 合理的初始化權重值。初始化權重,使每個神經元儘可能不要取極大或極小值,以躲開梯度消失的區域。
- 使用relu代替sigmoid和tanh作為啟用函式。原理請參考上一篇文章零基礎入門深度學習(4) - 卷積神經網路的啟用函式一節。
- 使用其他結構的RNNs,比如長短時記憶網路(LTSM)和Gated Recurrent Unit(GRU),這是最流行的做法。我們將在以後的文章中介紹這兩種網路。