LSTM-基本原理-前向傳播與反向傳播過程推導
前言
最近在實踐中用到LSTM模型,一直在查詢相關資料,推導其前向傳播、反向傳播過程。
LSTM有很多變體,查到的資料的描述也略有差別,且有一些地方讓我覺得有些困惑。目前查到的資料中我認為這個國外大神的部落格寫的比較清晰:
http://arunmallya.github.io/writeups/nn/lstm/index.html#/
這個部落格中的有些步驟有一定跳躍性,本文中的描述主要基於這篇部落格中的實現過程進行更細緻的推導,在此分享。本人能力有限,如果有不妥當之處,歡迎大家交流、指正。
LSTM演算法介紹
基本網路結構&前向傳播
先附上LSTM整體結構圖:
後面的所有推導都是基於這張圖的結構。
LSTM中比較重要的改進是加入了細胞狀態 : Ct代表t時刻的細胞狀態。
下面具體展開寫一下各部分結構:
輸入以及門計算部分
上述公式中表示了LSTM的四個基本結構:
輸入門、輸出門、遺忘門、用於更新細胞狀態的部分(圖中at 部分,這個好像沒有專門的名稱)
其中xt 為t時刻的輸入,我們設定其大小為: n X 1
注:此處x直接把偏置項涵蓋了,即加入了x0=1
ht 為t時刻的隱藏狀態,我們設定其大小為: d X 1
ct 為t時刻的細胞狀態,設定其大小為: d X 1
注:隱藏狀態和細胞狀態一般情況下維度一致。
W*: 維度大小:d X n
U* : 維度大小:d X d
注:W* 與 U*即為模型的引數矩陣,共八個,也就是說,訓練LSTM時訓練的就是這八個矩陣。
σ 代表sigmoid函式 , tanh 代表 tanh函式。
這裡的σ和 tanh是按元素操作,以σ為例,σ(X)即對向量X中的每個元素Xi分別計算
σ(Xi),σ(X)與 X 維度相同。
圖中 zt
細胞狀態更新部分
上圖中公式即為細胞狀態更新公式:
ct=it⊙at+ft⊙ct-1
其中⊙表示按元素乘(兩矩陣維度一致,相同位置元素相乘,結果矩陣維度不變)。
輸出部分
根據如下公式:
ht=ot⊙tanh(ct)
輸出t時刻隱藏狀態ht.
至此,LSTM的網路架構,前向傳播部分就梳理結束了。
反向傳播:梯度計算
為了理解反向傳播過程,先看一下前向傳播過程按時間將網路結構展開的效果圖:
可以看到,t時刻的細胞狀態ct對當前時刻隱藏狀態ht和下一時刻的細胞狀態ct+1都有貢獻。所以計算ct
注:此處預設的是多輸入-多輸出的情況,即每一時刻的隱藏狀態ht均參與損失函式計算。
對應反向傳播示意圖如下圖:
反向傳播的終極目標是為了計算梯度,更新引數,所以
要計算損失函式對W* 與 U*的偏導數。
下面一步步推導:
不考慮損失函式的形式,這裡我們泛化地設定:
然後將誤差逐層反向傳播。
注意:上圖中的δct的等式右端其實只是來自ht部分的梯度,下文計算來自ct+1的梯度,二者相加才是真正的δct。
上圖中以及後文會用到以下函式的導數計算公式:
注意,根據上圖中的
δct-1=δct⊙δft
我們可以得到:
δct=δct+1⊙δft+1
其實這部分只是來自ct+1的梯度。
綜合前兩張圖中關於δct的計算可得:
δct=δht⊙ot⊙(1-tanh2(ct))+δct+1⊙δft+1
好,下面我們繼續反向傳播:
這裡的寫法稍微有些跳躍,其實不同的W* 與 U*的偏導數計算類似,所以原部落格作者把他們整合在了一個表示式中。
這裡以Wc為例具體算一下:
類似地可以計算其他引數矩陣的梯度,
最終寫成圖中的整合矩陣形式:
δWt =δzt X (It)T
至此,我們求出了在t時刻,損失函式相對於各引數的梯度。
最終,根據上式累加不同時刻的梯度,進行引數更新。
小結
以上就是LSTM基本的網路結構,以及前向、反向傳播過程。由於時間和水平有限,文中可能有不妥當之處,歡迎大家批評指正,後續會不斷修改。