1. 程式人生 > >RNN, LSTM, GRU 公式總結

RNN, LSTM, GRU 公式總結

Vanilla RNN

參考 RNN wiki 的描述,根據隱層 ht 接受的是上時刻的隱層(hidden layer) ht1 還是上時刻的輸出(output layer)yt1,分成了兩種 RNN,定義如下:

RNN wiki
  • Elman network 接受上時刻的隱層 ht1
  • Jordan network 接受上時刻的輸出 yt1
RNN from Nature magazine

Bidirectional RNNs

雙向的 RNN 是同時考慮“過去”和“未來”的資訊,考慮上圖,正常情況下,輸入(黑色點)沿著黑色的實線箭頭傳輸到隱層(黃色點),再沿著紅色實線傳到輸出(紅色點)。黑色實線做完前向傳播後,在 Bidirectional RNNs 卻先不急著後向傳播,而是從末尾的時刻沿著虛線的方向再回傳回來。最後把兩個方向得到的啟用值拼在一起(concatenate),當做最後的啟用值。那麼後向傳播也是類似,要轉一圈回來。

Stacked Bidirectional RNNs


堆多層的 recurrent layer,如上圖所示,可以增加模型的引數,提高模型的學習能力。每層的 hidden state 不僅要輸給下一時刻,還是當做是此時刻下一層的輸入。上圖展示了雙向的三層 RNNs,那麼 hidden state 的維度是 hidden_dim * 6,輸出的維度為 hidden_dim * 2,因為是兩個方向最有一層 hidden state 拼接的結果。

原始的 RNN 很難訓練,主要是因為存在梯度消失(gradient vanishing problem)和梯度爆炸問題(gradient explosion problem)。梯度消失導致無法抓住長時刻依賴,因此效果不好,後面的 LSTM 和 GRU 的新結構,就是為了對付這個問題。而梯度爆炸問題雖然不是每次都出現,但是一旦出現就很致命。一般會選擇用截斷的梯度(clipped gradient)來更新引數,或者直接把梯度 rescale 到一個固定模大小的範圍。

LSTM

由於 Vanilla RNN 具有梯度消失問題,對長關係的依賴(Long-Term Dependencies)的建模能力不夠強大。這句話是什麼意思呢?就是說,原來的 RNN,由於結構上的限制,很長的時刻以前的輸入,對現在的網路影響非常小,後向傳播時那些梯度,也很難影響很早以前的輸入,即會出現梯度消失的問題。而 LSTM 通過構建一些門(Gate),讓網路能記住那些非常重要的資訊,而這個核心的結構,就是 cell state。比如遺忘門,來選擇性清空過去的記憶和更新較新的資訊。

上面講的比較迷糊,如果我有新的理解會更新這個部落格。另外可以參考大神的部落格 Understanding LSTM Networks

,把 LSTM 講的深入淺出,並且提到了很多的變種和展望。

有兩種常見的 LSTM 結構,如 LSTM wiki 總結的,第一種是帶遺忘門的 Traditional LSTM,公式如下:

traditional lstm

前三行是三個門,分別是遺忘門 ft,輸入門 it,輸出門 ot,輸入都是 [xt,ht1],只是引數不同,然後要經過一個啟用函式,把值放縮到 [0,1] 附近。第四行 ct 是 cell state,由上一時刻的 ct1 和輸入得到。如果遺忘門 ft 取 0 的話,那麼上一時刻的狀態就會全部被清空(清空 or 遺忘?),然後只關注此時刻的輸入。輸入門 it 決定是否接收此時刻的輸入。最後輸出門 ot 決定是否輸出 cell state。

注意這裡的輸出 ht 只是對應上面 RNN 的隱層,而非輸出。這裡的輸出 ht 又會被當做是下一時刻的輸入。

有時候第四個公式裡的 σ(Wcxt+Ucht1+bc) 可以單獨抽出來,寫作 c˜,叫做 new memory content,那麼第四個公式就可以寫作是 ct=ftct1+itc˜,這樣一來 cell state 的更新來源就很明顯了,一部分是上時刻的自己,一部分是新的 new memory content,而且兩個來源是相互獨立地由兩個門控制的。遺忘門控制是否記住以前的那些特徵,輸入門決定是否接收當前的輸入。後面可以看到 GRU 其實把這兩個門合二為一了。

第二種是帶遺忘門的 Peephole LSTM,公式如下,

Peephole LSTM

和上面的公式做比較,發現只是把 ht1 都換成了 ct1,即三個門的輸入都改成了 [xt,ct1]。因為是從 cell state 裡取得資訊,所以叫窺視孔(peephole)。

還有把兩種結構結合起來的,可以用下圖描述,

圖裡的連著門的那些虛線就是窺視孔。三個輸入分別是 [xt,ht1,ct1] 。上圖引自 Alex Graves 的論文 Supervised Sequence Labelling with Recurrent Neural Networks 中對 LSTM 的描述。注意該論文裡的輸出門和其他兩個門稍稍不同,接受的是 ct,而非 ct1,我沒有找到這樣做的解釋。

GRU

GRU 的結構和 LSTM 類似,但是精簡一些,見下圖

GRU

公式如下:

zt=σ(Wzxt+Uzht1)rt=σ(Wtxt+Utht1)h˜t=tanh(Wxt+U(rtht1))ht=(1zt)ht1+zth˜t

這四行公式解釋如下:

  • zt 是 update gate,更新 activation 時的邏輯閘
  • rt 是 reset gate,決定 candidate activation 時,