1. 程式人生 > 實用技巧 >NLP經典模型入門-seq2seq

NLP經典模型入門-seq2seq

前言:筆者之前是cv方向,因為工作原因需要學習NLP相關的模型,因此特意梳理一下關於NLP的幾個經典模型,由於有基礎,這一系列不會關注基礎內容或者公式推導,而是更側重對整體原理的理解。順便推薦兩個很不錯的github專案——開箱即用的中文教程以及演算法更全但是有些跑不通的英文教程

一. 從encode和decode說起

encode和decode是一個非常常見的結構。encode可以理解為從輸入得到特徵的過程,而decode可以理解為從特徵得到結果的過程。同樣適用https://zhuanlan.zhihu.com/p/28054589的RNN網路結構圖來說明。

實際上,我們利用RNN由輸入得到h1~h4這4個狀態的過程就是encode,同理包括從image經過大量cnn得到feature等。而decode則是從特徵得到結果,比如在影象分類問題中,這一步驟可能是對feature map做avgpool然後經過若干個fc和softmax得到類別概率。而在之前的RNN中,decode過程則是直接取最後一個狀態h4,然後經過fc和softmax得到類別概率。

二. seq2seq與attention

原理

seq2seq指的是輸入和輸出都是序列的問題,比如翻譯。在這種時候,顯然是不能用之前RNN的思路的。但是實際上,兩者並沒有太大的差異。比如,我們完全可以使用相同的encode過程,用一個RNN來提取特徵,但是我們每個\(h_i\)都經過fc和softmax,求出每個位置對應的最可能是什麼詞,這就是一個最簡單的翻譯網路了。當然直觀上輸入和輸出必須是等長的,如果不等長,可以取兩者最大公共長度補齊;或者也有其它方法,詳情參考最上面給的部落格。

當然也有很多不同的操作,比如你可以給h1~h4做一個特徵融合。然後用融合後的特徵去decode,這裡的c就代表從encode得到的最終用於decode的特徵。這裡我們可以把之前介紹的attention機制湧過來。

decode特徵也可以是多個,比如現在是一個4對3的任務,有4個狀態特徵h1-h4,則可以用不同的注意力機制權重得到c1-c3,這裡就是之前提到的翻譯例子了:

其它

Pytorch中,RNN和LSTM模組的輸入和輸出必須是等長的;意思是如果你輸入是x個狀態,那麼輸出也必然是x個狀態。所以如果輸入和輸出是不等長的,需要自己填補到等長,Pytorch也提供了相關函式填補函式pad_sequence以及對應的解壓函式。如果需要使用的可以自己瞭解。