tensorflow學習教程(十二)隨時間反向傳播BPTT
阿新 • • 發佈:2018-12-30
1、概述
上一節介紹了BP,這一節就簡單介紹一下BPTT。
2、網路結構
RNN正向傳播可以用上圖表示,這裡忽略偏置。
上圖中,
x(1:T)表示輸入序列,
y(1:T)表示輸出序列,
Y(1:T)表示標籤序列,
ht表示隱含層輸出,
st表示隱含層輸入,
zt表示經過啟用函式之前的輸出層輸出。
3、前向傳播
忽略偏置的前向傳播過程如下:
st=Uht-1+Wxt
ht=f(st)
zt=Vht
yt=f(zt)
其中,f是啟用函式。U、W、V三個權重在時間維度上是共享的。
每個時刻都有輸出,所以每個時刻都有損失,記t時刻的損失為Et,那麼對於樣本x(1:T)來說,
總損失,使用交叉熵做損失函式,則
3、反向傳播BPTT
跟BP類似,想求哪個權值對整體誤差的影響就用誤差對其求偏導。
3.1、E對V的梯度
根據鏈式法則有,
其中,
所以,
3.2、E對U的梯度
這個是BPTT與BP之所以不同的地方,因為不止t時刻隱含層與U有關,之前所有的隱含層都跟U有關。所以有,
其中,
假設
則
3、梯度爆炸和梯度消失
用鏈式法則求損失E對U的梯度為,
其中,
定義
則為,如果,則當 t-k→∞時,→∞,會造成系統不穩定,這就是所謂的梯度爆炸問題。相反,如果,則當 t-k→∞時,