1. 程式人生 > >lstm的數學推導

lstm的數學推導

本文是根據以下三篇文章整理的LSTM推導過程,公式都源於文章,只是一些比較概念性的東西,要coding的話還要自己去吃透以下文章。

前向傳播:

1、計算三個gate(in, out, forget)的輸入和cell的輸入:

zinj(t)=mwinjmym(t1)+v=1SjwinjcvjScvj(t1),(1)(1)zinj(t)=∑mwinjmym(t−1)+∑v=1SjwinjcjvScjv(t−1), zφj(t)=mwφjmym(t1)+v=1SjwφjcvjScvj(t1),(2)(2)zφj(t)=∑mwφjmym(t−1)+∑v=1SjwφjcjvScjv(t−1),
zoutj(t)=mwoutjmym(t1)+v=1SjwoutjcvjScvj(t1),(3)(3)zoutj(t)=∑mwoutjmym(t−1)+∑v=1SjwoutjcjvScjv(t−1), zctj(t)=mwctjmym(t1)+v=1SjwctjcvjScvj(t1),(4)(4)zcjt(t)=∑mwcjtmym(t−1)+∑v=1SjwcjtcjvScjv(t−1),

2、計算上述各個gate和cell的啟用值:

yinj(t)=finj(zinj(t)),(5)(5)yinj(t)=finj(zinj(t)),
yφj(t)=fφj(zφj(t)),(6)(6)yφj(t)=fφj(zφj(t)), youtj(t)=foutj(zoutj(t)),(7)(7)youtj(t)=foutj(zoutj(t)), Scvj(0)=0,Scvj(t)=yφj(t)Scvj(t1)+yinj(t)g(zcvj(t)),(8)(8)Scjv(0)=0,Scjv(t)=yφj(t)Scjv(t−1)+yinj(t)g(zcjv(t)), ycvj(t)=youtjScvj(t),(9)(9)ycjv(t)=youtjScjv(t),

3、假定該網路為一個標準的三層結構(如下圖所示),即一個輸入層,一個隱層和一個輸出層。則對於一個輸出單元,我們可以按下述的方式計算它的輸入和啟用值。其中m為所有與該輸出單元連線的單元(包括輸入層的和隱層的)。


zk(t)=mwkmym(t),(10)(10)zk(t)=∑mwkmym(t), yk(t)=fk(zk(t)),(11)(11)yk(t)=fk(zk(t)),

4、計算當前時間點對應狀態對input gate和、forget gate以及cell的偏導數。這裡跟CNN不一樣,CNN前向只是求值,沒有傳遞梯度。但對於lstm,由於內部狀態的改變依賴前一時間點的狀態,因此內部狀態的引數也會把錯誤傳遞到網路下一層,因此前向也涉及到梯度傳遞。