1. 程式人生 > >詳解 LSTM

詳解 LSTM

轉自:https://www.jianshu.com/p/dcec3f07d3b5


今天的內容有:

  1. LSTM 思路
  2. LSTM 的前向計算
  3. LSTM 的反向傳播
  4. 關於調參

LSTM

長短時記憶網路(Long Short Term Memory Network, LSTM),是一種改進之後的迴圈神經網路,可以解決RNN無法處理長距離的依賴的問題,目前比較流行。

長短時記憶網路的思路:

原始 RNN 的隱藏層只有一個狀態,即h,它對於短期的輸入非常敏感。
再增加一個狀態,即c,讓它來儲存長期的狀態,稱為單元狀態(cell state)。

把上圖按照時間維度展開:

在 t 時刻,LSTM 的輸入有三個:當前時刻網路的輸入值 x_t、上一時刻 LSTM 的輸出值 h_t-1、以及上一時刻的單元狀態 c_t-1
LSTM 的輸出有兩個:當前時刻 LSTM 輸出值 h_t、和當前時刻的單元狀態 c_t.

關鍵問題是:怎樣控制長期狀態 c ?

方法是:使用三個控制開關

第一個開關,負責控制繼續儲存長期狀態c;
第二個開關,負責控制把即時狀態輸入到長期狀態c;
第三個開關,負責控制是否把長期狀態c作為當前的LSTM的輸出。

如何在演算法中實現這三個開關?

方法:用 門(gate)

定義:gate 實際上就是一層全連線層,輸入是一個向量,輸出是一個 0到1 之間的實數向量。
公式為:


回憶一下它的樣子:


gate 如何進行控制?

方法:用門的輸出向量按元素乘以我們需要控制的那個向量
原理:門的輸出是 0到1 之間的實數向量,
當門輸出為 0 時,任何向量與之相乘都會得到 0 向量,這就相當於什麼都不能通過;
輸出為 1 時,任何向量與之相乘都不會有任何改變,這就相當於什麼都可以通過。


LSTM 前向計算

在 LSTM-1 中提到了,模型是通過使用三個控制開關來控制長期狀態 c 的:

這些開關就是用門(gate)來實現:

接下來具體看這三重門


LSTM 的前向計算:

一共有 6 個公式

遺忘門(forget gate)
它決定了上一時刻的單元狀態 c_t-1 有多少保留到當前時刻 c_t

輸入門(input gate)
它決定了當前時刻網路的輸入 x_t 有多少儲存到單元狀態 c_t

輸出門(output gate)
控制單元狀態 c_t 有多少輸出到 LSTM 的當前輸出值 h_t


遺忘門的計算為:

forget

遺忘門的計算公式中:
W_f 是遺忘門的權重矩陣,[h_t-1, x_t] 表示把兩個向量連線成一個更長的向量,b_f 是遺忘門的偏置項,σ 是 sigmoid 函式。


輸入門的計算:

input

根據上一次的輸出和本次輸入來計算當前輸入的單元狀態:

當前輸入的單元狀態c_t

當前時刻的單元狀態 c_t 的計算:由上一次的單元狀態 c_t-1 按元素乘以遺忘門 f_t,再用當前輸入的單元狀態 c_t 按元素乘以輸入門 i_t,再將兩個積加和:
這樣,就可以把當前的記憶 c_t 和長期的記憶 c_t-1 組合在一起,形成了新的單元狀態 c_t
由於遺忘門的控制,它可以儲存很久很久之前的資訊,由於輸入門的控制,它又可以避免當前無關緊要的內容進入記憶。

當前時刻的單元狀態c_t

輸出門的計算:

output

LSTM 的反向傳播訓練演算法

主要有三步:

1. 前向計算每個神經元的輸出值,一共有 5 個變數,計算方法就是前一部分:

2. 反向計算每個神經元的誤差項值。與 RNN 一樣,LSTM 誤差項的反向傳播也是包括兩個方向:
一個是沿時間的反向傳播,即從當前 t 時刻開始,計算每個時刻的誤差項;
一個是將誤差項向上一層傳播。

3. 根據相應的誤差項,計算每個權重的梯度。


gate 的啟用函式定義為 sigmoid 函式,輸出的啟用函式為 tanh 函式,導數分別為:

具體推導公式為:

具體推導公式為:


目標是要學習 8 組引數,如下圖所示:

又權重矩陣 W 都是由兩個矩陣拼接而成,這兩部分在反向傳播中使用不同的公式,因此在後續的推導中,權重矩陣也要被寫為分開的兩個矩陣。

接著就來求兩個方向的誤差,和一個梯度計算。
這個公式推導過程在本文的學習資料中有比較詳細的介紹,大家可以去看原文:
https://zybuluo.com/hanbingtao/note/581764


1. 誤差項沿時間的反向傳遞:

定義 t 時刻的誤差項:

目的是要計算出 t-1 時刻的誤差項:

利用 h_t c_t 的定義,和全導數公式,可以得到 將誤差項向前傳遞到任意k時刻的公式:


2. 將誤差項傳遞到上一層的公式:


3. 權重梯度的計算:

以上就是 LSTM 的訓練演算法的全部公式。


關於它的 Tuning 有下面幾個建議:

來自 LSTM Hyperparameter Tuning:
https://deeplearning4j.org/lstm

還有一個用 LSTM 做 text_generation 的例子

https://github.com/fchollet/keras/blob/master/examples/lstm_text_generation.py

學習資料:
https://zybuluo.com/hanbingtao/note/581764