LSTM模型與前向反向傳播演算法
在迴圈神經網路(RNN)模型與前向反向傳播演算法中,我們總結了對RNN模型做了總結。由於RNN也有梯度消失的問題,因此很難處理長序列的資料,大牛們對RNN做了改進,得到了RNN的特例LSTM(Long Short-Term Memory),它可以避免常規RNN的梯度消失,因此在工業界得到了廣泛的應用。下面我們就對LSTM模型做一個總結。
1. 從RNN到LSTM
在RNN模型裡,我們講到了RNN具有如下的結構,每個序列索引位置t都有一個隱藏狀態$h^{(t)}$。
如果我們略去每層都有的$o^{(t)}, L^{(t)}, y^{(t)}$,則RNN的模型可以簡化成如下圖的形式:
圖中可以很清晰看出在隱藏狀態$h^{(t)}$由$x^{(t)}$和$h^{(t-1)}$得到。得到$h^{(t)}$後一方面用於當前層的模型損失計算,另一方面用於計算下一層的$h^{(t+1)}$。
由於RNN梯度消失的問題,大牛們對於序列索引位置t的隱藏結構做了改進,可以說通過一些技巧讓隱藏結構複雜了起來,來避免梯度消失的問題,這樣的特殊RNN就是我們的LSTM。由於LSTM有很多的變種,這裡我們以最常見的LSTM為例講述。LSTM的結構如下圖:
可以看到LSTM的結構要比RNN的複雜的多,真佩服牛人們怎麼想出來這樣的結構,然後這樣居然就可以解決RNN梯度消失的問題?由於LSTM怎麼可以解決梯度消失是一個比較難講的問題,我也不是很熟悉,這裡就不多說,重點回到LSTM的模型本身。
2. LSTM模型結構剖析
上面我們給出了LSTM的模型結構,下面我們就一點點的剖析LSTM模型在每個序列索引位置t時刻的內部結構。
從上圖中可以看出,在每個序列索引位置t時刻向前傳播的除了和RNN一樣的隱藏狀態$h^{(t)}$,還多了另一個隱藏狀態,如圖中上面的長橫線。這個隱藏狀態我們一般稱為細胞狀態(Cell State),記為$C^{(t)}$。如下圖所示:
除了細胞狀態,LSTM圖中還有了很多奇怪的結構,這些結構一般稱之為門控結構(Gate)。LSTM在在每個序列索引位置t的門一般包括遺忘門,輸入門和輸出門三種。下面我們就來研究上圖中LSTM的遺忘門,輸入門和輸出門以及細胞狀態。
2.1 LSTM之遺忘門
遺忘門(forget gate)顧名思義,是控制是否遺忘的,在LSTM中即以一定的概率控制是否遺忘上一層的隱藏細胞狀態。遺忘門子結構如下圖所示:
圖中輸入的有上一序列的隱藏狀態$h^{(t-1)}$和本序列資料$x^{(t)}$,通過一個啟用函式,一般是sigmoid,得到遺忘門的輸出$f^{(t)}$。由於sigmoid的輸出$f^{(t)}$在[0,1]之間,因此這裡的輸出f^{(t)}代表了遺忘上一層隱藏細胞狀態的概率。用數學表示式即為:$$f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)$$
其中$W_f, U_f, b_f$為線性關係的係數和偏倚,和RNN中的類似。$\sigma$為sigmoid啟用函式。
2.2 LSTM之輸入門
輸入門(input gate)負責處理當前序列位置的輸入,它的子結構如下圖:
從圖中可以看到輸入門由兩部分組成,第一部分使用了sigmoid啟用函式,輸出為$i^{(t)}$,第二部分使用了tanh啟用函式,輸出為$a^{(t)}$, 兩者的結果後面會相乘再去更新細胞狀態。用數學表示式即為:$$i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i)$$$$a^{(t)} =tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)$$
其中$W_i, U_i, b_i, W_a, U_a, b_a,$為線性關係的係數和偏倚,和RNN中的類似。$\sigma$為sigmoid啟用函式。
2.3 LSTM之細胞狀態更新
在研究LSTM輸出門之前,我們要先看看LSTM之細胞狀態。前面的遺忘門和輸入門的結果都會作用於細胞狀態$C^{(t)}$。我們來看看從細胞狀態$C^{(t-1)}$如何得到$C^{(t)}$。如下圖所示:
細胞狀態$C^{(t)}$由兩部分組成,第一部分是$C^{(t-1)}$和遺忘門輸出$f^{(t)}$的乘積,第二部分是輸入門的$i^{(t)}$和$a^{(t)}$的乘積,即:$$C^{(t)} = C^{(t-1)} \odot f^{(t)} + i^{(t)} \odot a^{(t)}$$
其中,$\odot$為Hadamard積,在DNN中也用到過。
2.4 LSTM之輸出門
有了新的隱藏細胞狀態$C^{(t)}$,我們就可以來看輸出門了,子結構如下:
從圖中可以看出,隱藏狀態$h^{(t)}$的更新由兩部分組成,第一部分是$o^{(t)}$, 它由上一序列的隱藏狀態$h^{(t-1)}$和本序列資料$x^{(t)}$,以及啟用函式sigmoid得到,第二部分由隱藏狀態$C^{(t)}$和tanh啟用函式組成, 即:$$o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o)$$$$h^{(t)} = o^{(t)} \odot tanh(C^{(t)})$$
通過本節的剖析,相信大家對於LSTM的模型結構已經有了解了。當然,有些LSTM的結構和上面的LSTM圖稍有不同,但是原理是完全一樣的。
3. LSTM前向傳播演算法
現在我們來總結下LSTM前向傳播演算法。LSTM模型有兩個隱藏狀態$h^{(t)}, C^{(t)}$,模型引數幾乎是RNN的4倍,因為現在多了$W_f, U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o$這些引數。
前向傳播過程在每個序列索引位置的過程為:
1)更新遺忘門輸出:$$f^{(t)} = \sigma(W_fh^{(t-1)} + U_fx^{(t)} + b_f)$$
2)更新輸入門兩部分輸出:$$i^{(t)} = \sigma(W_ih^{(t-1)} + U_ix^{(t)} + b_i)$$$$a^{(t)} = tanh(W_ah^{(t-1)} + U_ax^{(t)} + b_a)$$
3)更新細胞狀態:$$C^{(t)} = C^{(t-1)} \odot f^{(t)} + i^{(t)} \odot a^{(t)}$$
4)更新輸出門輸出:$$o^{(t)} = \sigma(W_oh^{(t-1)} + U_ox^{(t)} + b_o)$$$$h^{(t)} = o^{(t)} \odot tanh(C^{(t)})$$
5)更新當前序列索引預測輸出:$$\hat{y}^{(t)} = \sigma(Vh^{(t)} + c)$$
4. LSTM反向傳播演算法推導關鍵點
有了LSTM前向傳播演算法,推導反向傳播演算法就很容易了, 思路和RNN的反向傳播演算法思路一致,也是通過梯度下降法迭代更新我們所有的引數,關鍵點在於計算所有引數基於損失函式的偏導數。
在RNN中,為了反向傳播誤差,我們通過隱藏狀態$h^{(t)}$的梯度$\delta^{(t)}$一步步向前傳播。在LSTM這裡也類似。只不過我們這裡有兩個隱藏狀態$h^{(t)}$和$C^{(t)}$。這裡我們定義兩個$\delta$,即:$$\delta_h^{(t)} = \frac{\partial L}{\partial h^{(t)}}$$$$\delta_C^{(t)} = \frac{\partial L}{\partial C^{(t)}}$$
反向傳播時只使用了$\delta_C^{(t)}$,變數$\delta_h^{(t)}$僅為幫助我們在某一層計算用,並沒有參與反向傳播,這裡要注意。如下圖所示:
而在最後的序列索引位置$\tau$的$\delta_h^{(\tau)}$和 $\delta_C^{(\tau)} $為:$$\delta_h^{(\tau)} =\frac{\partial L}{\partial O^{(\tau)}} \frac{\partial O^{(\tau)}}{\partial h^{(\tau)}} = V^T(\hat{y}^{(\tau)} - y^{(\tau)})$$$$\delta_C^{(\tau)} =\frac{\partial L}{\partial h^{(\tau)}} \frac{\partial h^{(\tau)}}{\partial C^{(\tau)}} = \delta_h^{(\tau)} \odot o^{(\tau)} \odot (1 - tanh^2(C^{(\tau)}))$$
接著我們由$\delta_C^{(t+1)}$反向推導$\delta_C^{(t)}$。
$\delta_h^{(t)}$的梯度由本層的輸出梯度誤差決定,即:$$ \delta_h^{(t)} =\frac{\partial L}{\partial h^{(t)}} = V^T(\hat{y}^{(t)} - y^{(t)}) $$
而$\delta_C^{(t)}$的反向梯度誤差由前一層$\delta_C^{(t+1)}$的梯度誤差和本層的從$h^{(t)}$傳回來的梯度誤差兩部分組成,即:$$\delta_C^{(t)} =\frac{\partial L}{\partial C^{(t+1)}} \frac{\partial C^{(t+1)}}{\partial C^{(t)}} + \frac{\partial L}{\partial h^{(t)}}\frac{\partial h^{(t)}}{\partial C^{(t)}} = \delta_C^{(t+1)}\odot f^{(t+1)} + \delta_h^{(t)} \odot o^{(t)} \odot (1 - tanh^2(C^{(t)}))$$
有了$\delta_h^{(t)}$和$\delta_C^{(t)}$, 計算這一大堆引數的梯度就很容易了,這裡只給出$W_f$的梯度計算過程,其他的$U_f, b_f, W_a, U_a, b_a, W_i, U_i, b_i, W_o, U_o, b_o,V, c$的梯度大家只要照搬就可以了。$$\frac{\partial L}{\partial W_f} = \sum\limits_{t=1}^{\tau}\frac{\partial L}{\partial C^{(t)}} \frac{\partial C^{(t)}}{\partial f^{(t)}} \frac{\partial f^{(t)}}{\partial W_f} =\sum\limits_{t=1}^{\tau} \delta_C^{(t)} \odot C^{(t-1)} \odot f^{(t)}\odot(1-f^{(t)}) (h^{(t-1)})^T$$
5. LSTM小結
LSTM雖然結構複雜,但是隻要理順了裡面的各個部分和之間的關係,進而理解前向反向傳播演算法是不難的。當然實際應用中LSTM的難點不在前向反向傳播演算法,這些有演算法庫幫你搞定,模型結構和一大堆引數的調參才是讓人頭痛的問題。不過,理解LSTM模型結構仍然是高效使用的前提。
(歡迎轉載,轉載請註明出處。歡迎溝通交流: [email protected])
參考資料:
2) Deep Learning, book by Ian Goodfellow, Yoshua Bengio, and Aaron Courville