1. 程式人生 > >真正的完全圖解Seq2Seq Attention模型

真正的完全圖解Seq2Seq Attention模型

作者:盛源車

知乎專欄:魔法抓的學習筆記

五分鐘看懂seq2seq attention模型。

本文通過圖片,詳細地畫出了seq2seq+attention模型的全部流程,幫助小夥伴們無痛理解機器翻譯等任務的重要模型。

seq2seq 是一個Encoder–Decoder 結構的網路,它的輸入一個序列,輸出也是一個序列, Encoder 中將一個可變長度的訊號序列變為固定長度的向量表達,Decoder 將這個固定長度的向量變成可變長度的目標的訊號序列。--簡書

好了別管了,接下來開始刷圖吧。

大框架

想象一下翻譯任務,input是一段英文,output是一段中文。

公式(直接跳過看圖最佳)

輸入: x = (x_1,...,x_{T_x})

輸出: y = (y_1,...,y_{T_y})

(1) h_t = RNN_{enc}(x_t, h_{t-1}) , Encoder方面接受的是每一個單詞word embedding,和上一個時間點的hidden state。輸出的是這個時間點的hidden state。

(2) s_t = RNN_{dec}(\hat{y_{t-1}},s_{t-1}) , Decoder方面接受的是目標句子裡單詞的word embedding,和上一個時間點的hidden state。

(3) c_i = \sum_{j=1}^{T_x} \alpha_{ij}h_j , context vector是一個對於encoder輸出的hidden states的一個加權平均。

(4) \alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})} , 每一個encoder的hidden states對應的權重。

(5) e_{ij} = score(s_i, h_j) , 通過decoder的hidden states加上encoder的hidden states來計算一個分數,用於計算權重(4)

(6) \hat{s_t} = tanh(W_c[c_t;s_t]), 將context vector 和 decoder的hidden states 串起來。

(7) p(y_t|y_{<t},x) = softmax(W_s\hat{s_t}) ,計算最後的輸出概率。

詳細圖

左側為Encoder+輸入,右側為Decoder+輸出。中間為Attention。

(1) h_t = RNN_{enc}(x_t, h_{t-1}) , Encoder方面接受的是每一個單詞word embedding,和上一個時間點的hidden state。輸出的是這個時間點的hidden state。

從左邊Encoder開始,輸入轉換為word embedding, 進入LSTM。LSTM會在每一個時間點上輸出hidden states。如圖中的h1,h2,...,h8。

(2) s_t = RNN_{dec}(\hat{y_{t-1}},s_{t-1}) , Decoder方面接受的是目標句子裡單詞的word embedding,和上一個時間點的hidden state。

接下來進入右側Decoder,輸入為(1) 句首 &amp;amp;amp;amp;amp;amp;amp;lt;sos&amp;amp;amp;amp;amp;amp;amp;gt;符號,原始context vector(為0),以及從encoder最後一個hidden state: h8。LSTM的是輸出是一個hidden state。(當然還有cell state,這裡沒用到,不提。)

(3) c_i = \sum_{j=1}^{T_x} \alpha_{ij}h_j , context vector是一個對於encoder輸出的hidden states的一個加權平均。

(4) \alpha_{ij} = \frac{exp(e_{ij})}{\sum_{k=1}^{T_x}exp(e_{ik})} , 每一個encoder的hidden states對應的權重。

(5) e_{ij} = score(s_i, h_j) , 通過decoder的hidden states加上encoder的hidden states來計算一個分數,用於計算權重(4)

Decoder的hidden state與Encoder所有的hidden states作為輸入,放入Attention模組開始計算一個context vector。之後會介紹attention的計算方法。

下一個時間點

來到時間點2,之前的context vector可以作為輸入和目標的單詞串起來作為lstm的輸入。之後又回到一個hiddn state。以此迴圈。

(6) \hat{s_t} = tanh(W_c[c_t;s_t]), 將context vector 和 decoder的hidden states 串起來。

(7) p(y_t|y_{<t},x) = softmax(W_s\hat{s_t}) ,計算最後的輸出概率。

另一方面,context vector和decoder的hidden state合起來通過一系列非線性轉換以及softmax最後計算出概率。

在luong中提到了三種score的計算方法。這裡圖解前兩種:

Attention score function: dot

輸入是encoder的所有hidden states H: 大小為(hid dim, sequence length)。decoder在一個時間點上的hidden state, s: 大小為(hid dim, 1)。

第一步:旋轉H為(sequence length, hid dim) 與s做點乘得到一個 大小為(sequence length, 1)的分數

第二步:對分數做softmax得到一個合為1的權重

第三步:將H與第二步得到的權重做點乘得到一個大小為(hid dim, 1)的context vector

Attention score function: general

輸入是encoder的所有hidden states H: 大小為(hid dim1, sequence length)。decoder在一個時間點上的hidden state, s: 大小為(hid dim2, 1)。此處兩個hidden state的緯度並不一樣。

第一步:旋轉H為(sequence length, hid dim1) 與 Wa [大小為 hid dim1, hid dim 2)] 做點乘, 再和s做點乘得到一個 大小為(sequence length, 1)的分數

第二步:對分數做softmax得到一個合為1的權重

第三步:將H與第二步得到的權重做點乘得到一個大小為(hid dim, 1)的context vector

完結