記憶網路RNN、LSTM與GRU
一般的神經網路輸入和輸出的維度大小都是固定的,針對序列型別(尤其是變長的序列)的輸入或輸出資料束手無策。RNN通過採用具有記憶的隱含層單元解決了序列資料的訓練問題。LSTM、GRU屬於RNN的改進,解決了RNN中梯度消失爆炸的問題,屬於序列資料訓練的常用方案。
RNN
結構
傳統的神經網路的輸入和輸出都是確定的,RNN的輸入和輸出都是不確定的sequence
資料。其結構如下:
具體地,RNN有隱含層,隱含層也是記憶層,其狀態(權值)會傳遞到下一個狀態中。
訓練
訓練步驟如下:
- 構建損失函式
- 求損失函式對權值的梯度
- 採用梯度下降法更新權值引數
關於損失函式,根據需要選擇構建即可,下面提供兩種常見的損失函式:
關於梯度下降,採用BPTT(Backpropagation through time)演算法,該演算法的核心是對每一個時間戳,計算該時間戳中權重的梯度,然後更新權重。需要注意的是,不同時間戳同樣權重的梯度可能是不一樣的,如下圖所示都減去,相當於更新同一塊記憶體區域中的權重。
應用
- 多對多:詞性標註pos tagging、語音識別、name entity recognition(區分poeple、organizations、places、information extration(區分place of departure、destination、time of departure、time of arrival, other)、機器翻譯
- 多對一:情感分析
- 一對多:caption generation
RNN Variants
RNN的變種大致包含下面3個思路:
- 增加隱含層的輸入引數:例如除了
ht−1,xt ,還可以包含yt−1 作為輸入。 - 增加隱含層的深度
- 雙向RNN
LSTM
結構
- 單個時間戳,RNN輸入1個x,輸出1個y
- 單個時間戳,LSTM輸入4個x,輸出1個y
相比RNN,LSTM的輸入多了3個x,對應3個gate,這3個gate分別是:
- input gate:控制輸入
- forget gate:控制cell
- output gate:控制輸出
涉及到的啟用函式共5個,其中3個控制gate的(通常用sigmoid函式,模擬gate的開閉狀態),1個作用於輸入上,一個作用於cell的輸出上。
LSTM單個時間戳的具體執行如下:
- 輸入:4個輸入
x ,1個cell的狀態c - 輸出:1個輸出
a ,1個更新的cell狀態c′
梯度消失及梯度爆炸
首先,要明白RNN中梯度消失與梯度爆炸的原因:在時間戳的更新中,cell的狀態不斷乘以
接下來,要明白LSTM如何解決RNN中梯度消失與爆炸的問題。
針對梯度消失,RNN中當獲取
總結來說,LSTM相比RNN,將
針對梯度爆炸,即使將
GRU
結構
GRU相比LSTM的3個gate,只用了兩個gate:
- update gate:
zt - reset gate: