【NLP】Attention原理和原始碼解析
對attention一直停留在淺層的理解,看了幾篇介紹思想及原理的文章,也沒實踐過,今天立個Flag,一天深入原理和原始碼!如果你也是處於attention model level one的狀態,那不妨好好看一下啦。
內容:
- 核心思想
- 原理解析(圖解+公式)
- 模型分類
- 優缺點
- TF原始碼解析
P.S. 拒絕長篇大論,適合有基礎的同學快速深入attention,不明白的地方請留言諮詢~
1. 核心思想
Attention的思想理解起來比較容易,就是在decoding階段對input中的資訊賦予不同權重。在nlp中就是針對sequence的每個time step input,在cv中就是針對每個pixel。
2. 原理解析
針對Seq2seq翻譯來說,rnn-based model差不多是圖1的樣子:
而比較基礎的加入attention與rnn結合的model是下面的樣子(也叫soft attention):
其中 是 對應的權重,算出所有權重後會進行softmax和加權,得到 。
可以看到Encoding和decoding階段仍然是rnn,但是decoding階使用attention的輸出結果 作為rnn的輸入。
那麼重點來了, 權重 是怎麼來的呢?常見有三種方法:
思想就是根據當前解碼“狀態”判斷輸入序列的權重分佈。
如果把attention剝離出來去看的話,其實是以下的機制:
輸入是query(Q), key(K), value(V),輸出是attention value。如果與之前的模型對應起來的話,query就是 ,key就是 ,value也是。模型通過Q和K的匹配計算出權重,再結合V得到輸出:
再深入理解下去,這種機制其實做的是定址(addressing),也就是模仿中央處理器與儲存互動的方式將儲存的內容讀出來,可以看一下李巨集毅老師的課程。
3. 模型分類
3.1 Soft/Hard Attention
soft attention:傳統attention,可被嵌入到模型中去進行訓練並傳播梯度
hard attention:不計算所有輸出,依據概率對encoder的輸出取樣,在反向傳播時需採用蒙特卡洛進行梯度估計
3.2 Global/Local Attention
global attention:傳統attention,對所有encoder輸出進行計算
local attention:介於soft和hard之間,會預測一個位置並選取一個視窗進行計算
3.3 Self Attention
傳統attention是計算Q和K之間的依賴關係,而self attention則分別計算Q和K自身的依賴關係。具體的詳解會在下篇文章給出~
4. 優缺點
優點:
- 在輸出序列與輸入序列“順序”不同的情況下表現較好,如翻譯、閱讀理解
- 相比RNN可以編碼更長的序列資訊
缺點:
- 對序列順序不敏感
- 通常和RNN結合使用,不能並行化
5. TF原始碼解析
發現已經有人解析得很明白了,即使TF程式碼有更新,原理應該還是差不多的,直接放上來吧:
顧秀森:Tensorflow原始碼解讀(一):AttentionSeq2Seq模型 zhuanlan.zhihu.com【參考資料】: