NLP 相關演算法 LSTM 演算法流程
LSTM希望通過改進的RNN內部計算方法來應對普通RNN經常面臨的梯度消失和梯度爆炸。基本思路是通過改變逆向傳播求導時單純的偏導連乘關係,從而避免較小的sigmoid或relu啟用函式偏導連乘現象。
RNN網路unfold以後,將按時間t展開為若干個結構相同的計算單元,每個計算單元在利用當前時間的輸入以外,還需要之前時間的輸出。以下將展示每個計算單元的內部計算流程,假設當前的計算單元對應時間為t。
每個計算單元內由input gate,forget gate和output gate三個“閘門”結構依先後順序構成。在每一個gate內部,相關的輸入都匹配專門的權重矩陣,各個輸入相加後都將匹配專門的bias向量,總體求和後需要通過專門的啟用函式進行處理形成輸出。 設定當期(即t期)輸入為
,前一期輸出為
。
input gate
input gate實際上是類似於一個filter,即用sigmoid啟用函式的啟用值過濾或加權實際的input。實際的input為:
sigmoid啟用函式filter為:
input gate層的最終輸出就是
與
的點乘,即元素層面的對應相乘。
inner state
LSTM較於普通RNN網路增加了一個內部狀態量 . 記憶的控制就是通過forget gate對於 的過濾而發揮作用。
forget gate
與input gate相同,forget gate也是一個sigmoid啟用函式啟用值形成的filter,用於對上一期的狀態量
進行過濾。
當期的狀態量
就是input gate層的輸出值與IG過濾後的上一期狀態量的簡單相加的結果。注意這裡的操作僅為簡單的相加,並沒有加入權重,不存在相乘,也沒有使用新的啟用函式,這一步驟是消除RNN反向傳播網路梯度消失或梯度爆炸的關鍵:
output gate
同之前的兩個gate類似,output gate也是一個sigmoid啟用函式filter,對當期的狀態量
進行過濾。
在接受過濾前,先使用tanh啟用函式進行區間壓縮:
以此對壓縮後的
進行過濾,形成最終當期計算單元的最終輸出:
和
將可用於下一期(t+1)計算單元的內部計算。