1. 程式人生 > >LSTM簡介以及數學推導(FULL BPTT)

LSTM簡介以及數學推導(FULL BPTT)

了解 組織 線表 含義 計算公式 增加 以及 限制 意思

http://blog.csdn.net/a635661820/article/details/45390671

前段時間看了一些關於LSTM方面的論文,一直準備記錄一下學習過程的,因為其他事兒,一直拖到了現在,記憶又快模糊了。現在趕緊補上,本文的組織安排是這樣的:先介紹rnn的BPTT所存在的問題,然後介紹最初的LSTM結構,在介紹加了遺忘控制門的,然後是加了peephole connections結構的LSTM,都是按照真實提出的時間順序來寫的。本文相當於把各個論文核心部分簡要匯集一下而做的筆記,已提供快速的了解。

一.rnn結構的BPTT學習算法存在的問題

先看一下比較典型的BPTT一個展開的結構,如下圖,這裏只考慮了部分圖,因為其他部分不是這裏要討論的內容。

技術分享

對於t時刻的誤差信號計算如下:

技術分享

這樣權值的更新方式如下:

技術分享

上面的公式在BPTT中是非常常見的了,那麽如果這個誤差信號一直往過去傳呢,假設任意兩個節點u, v他們的關系是下面這樣的:

技術分享

那麽誤差傳遞信號的關系可以寫成如下的遞歸式:

技術分享

n表示圖中一層神經元的個數,這個遞歸式的大概含義不難理解,要求t-q時刻誤差信號對t時刻誤差信號的偏導,就先求出t-q+1時刻對t時刻的,然後把求出來的結果傳到t-q時刻,遞歸停止條件是q = 1時,就是剛開始寫的那部分計算公式了。將上面的遞歸式展開後可以得到:

技術分享

論文裏面說的是可以通過歸納來證明,我沒仔細推敲這裏了,把裏面連乘展開看容易明白一點:

技術分享

整個結果式對T求和的次數是n^(q-1), 即T有n^(q-1)項,那麽下面看問題出在哪兒。

如果|T| > 1, 誤差就會隨著q的增大而呈指數增長,那麽網絡的參數更新會引起非常大的震蕩。

如果|T| < 1, 誤差就會消失,導致學習無效,一般激活函數用simoid函數,它的倒數最大值是0.25, 權值最大值要小於4才能保證不會小於1。

誤差呈指數增長的現象比較少,誤差消失在BPTT中很常見。在原論文中還有更詳細的數學分析,但是了解到此個人覺的已經足夠理解問題所在了。

二.最初的LSTM結構

為了克服誤差消失的問題,需要做一些限制,先假設僅僅只有一個神經元與自己連接,簡圖如下:

技術分享

根據上面的,t時刻的誤差信號計算如下:

技術分享

為了使誤差不產生變化,可以強制令下式為1:

技術分享

根據這個式子,可以得到:

技術分享

這表示激活函數是線性的,常常的令fj(x) = x, wjj = 1.0,這樣就獲得常數誤差流了,也叫做CEC。

但是光是這樣是不行的,因為存在輸入輸出處權值更新的沖突(這裏原論文裏面的解釋我不是很明白),所以加上了兩道控制門,分別是input gate, output gate,來解決這個矛盾,圖如下:

技術分享

圖中增加了兩個控制門,所謂控制的意思就是計算cec的輸入之前,乘以input gate的輸出,計算cec的輸出時,將其結果乘以output gate的輸出,整個方框叫做block, 中間的小圓圈是CEC, 裏面是一條y = x的直線表示該神經元的激活函數是線性的,自連接的權重為1.0

三.增加forget gate

最初lstm結構的一個缺點就是cec的狀態值可能會一直增大下去,增加forget gate後,可以對cec的狀態進行控制,它的結構如下圖: 技術分享 這裏的相當於自連接權重不再是1.0,而是一個動態的值,這個動態值是forget gate的輸出值,它可以控制cec的狀態值,在必要時使之為0,即忘記作用,為1時和原來的結構一樣。

四.增加Peephole的LSTM結構

上面增加遺忘門一個缺點是當前CEC的狀態不能影響到input gate, forget gate在下一時刻的輸出,所以增加了Peephole connections。結構如下: 技術分享 這裏的gate的輸入部分就多加了一個來源了,forget gate, input gate的輸入來源增加了cec前一時刻的輸出,output gate的輸入來源增加了cec當前時刻的輸出,另外計算的順序也必須保證如下:
  1. input gate, forget gate的輸入輸出
  2. cell的輸入
  3. output gate的輸入輸出
  4. cell的輸出(這裏也是block的輸出)

五.一個LSTM的FULL BPTT推導(用誤差信號)

我記得當時看論文公式推導的時候很多地方比較難理解,最後隨便谷歌了幾下,找到一個寫的不錯的類似課件的PDF,但是已經不知道出處了,很容易就看懂LSTM的前向計算,誤差反傳更新了。把其中關於LSTM的部分放上來,首先網絡的完整結構圖如下: 技術分享 這個結構也是rwthlm源碼包中LSTM的結構,下面看一下公式的記號:
  • wij表示從神經元i到j的連接權重(註意這和很多論文的表示是反著的)
  • 神經元的輸入用a表示,輸出用b表示
  • 下標 ι, φ 和 ω分別表示input gate, forget gate,output gate
  • c下標表示cell,從cell到 input, forget和output gate的peephole權重分別記做 wcι , wcφ and wcω
  • Sc表示cell c的狀態
  • 控制門的激活函數用f表示,g,h分別表示cell的輸入輸出激活函數
  • I表示輸入層的神經元的個數,K是輸出層的神經元個數,H是隱層cell的個數
前向的計算: 技術分享 誤差反傳更新: 技術分享 技術分享

LSTM簡介以及數學推導(FULL BPTT)