1. 程式人生 > >使用LSTM進行文字蘊含判斷

使用LSTM進行文字蘊含判斷

使用LSTM進行文字蘊含判斷


最近了解了一下什麼是文字蘊含,大概就是兩句話,如果能從前提句(premise)能推出假設句(hypothesis)或者這兩句話非常相似說的是同一個意思,那麼就是蘊含關係(entailment),否則就是矛盾關係(contradiction),如果看不出就是中立的(neutral)。所以說可以看成是一個三分類問題。有很多處理方法,其中一些通過某種方式對這兩句話表示成向量再進行匹配,但是一句話中並不是每個詞對於匹配來說都是同樣重要,比如兩句話的主語完全不同那麼這兩句話極有可能是矛盾的。於是,另一種方法採用一種詞對詞(word by word)的方式,對假設句採用從前往後的方式,每個詞和前提句比較。比較的時候利用注意力機制看看前提句中那些詞貢獻比較大,利用這種方式對假設句掃描一遍後再進行判斷。

模型

模型部分參照了這裡
實現了Learning Natural Language Inference with LSTM

h j s h_j^s h

k t h_k^t 分別為前提句和假設句中的詞 x j s x_j^s
x k t x_k^t 在對應LSTM的輸出。
a k t a_k^t 為對應於假設(hypothesis)中第k個詞 x k x_k 的attention向量,通過對前提句所有詞的LSTM輸出加權求和得到:
a k = j = 1 M α k j h j s a_k=\sum_{j=1}^{M}\alpha_{kj}h_j^s
其中, α k j \alpha_{kj} 為假設句中第k個詞對應於前提句中第j個詞的attention權重,可由下式計算:
α k j = e x p ( e k j ) j e x p ( e k j ) \alpha_{kj}=\frac{exp(e_{kj})}{\sum_{j'}exp(e_{kj'})}
e k j = w e t a n h ( W s h j s + W t h k t + W m h k 1 m ) e_{kj}=w^e\cdot tanh(W^sh_j^s+W^th_k^t+W^mh_{k-1}^m)
其中, w e d w_e\in \Re^d \cdot 為兩個向量的點乘, W W^* 為要訓練的權重, h k 1 m h_{k-1}^m 為第三個LSTM的輸出:
m k = [ a k , h k t ] m_k=[a_k, h_k^t]
h k m = L S T M ( m k , h k 1 m , c k 1 m ) h_k^m=LSTM(m_k, h_{k-1}^m, c_{k-1}^m)
其中, [ ] [\cdot ] 表示拼接操作。最後使用該LSTM最後的輸出狀態,使用全連線層去預測3種類別。

預處理及實現效果

資料集使用SNLI,預處理部分真的很討人厭,我用了torchtext簡化程式碼,像論文中一樣我用了與訓練的Glove詞向量,為了減少模型的引數,詞向量不進行訓練。論文中採用了一些技巧,比如對前提句後新增一個NULL關鍵字,對於沒有見過的詞采用其相鄰詞的詞向量求平均,這些我省略了。

效果並沒有論文中的那麼好,我試了一下只有0.8383,原因很可能是模型哪邊有問題或者沒有進行調參等等,也可能是訓練不夠充分,因為我只迭代了5個epoch,執行比較慢,時間長達2個小時。不過從attention視覺化可以看出模型確實進行了對齊操作。下圖分別是entailment,contradiction,neutral對應的注意力分佈圖,圖1可以看出dog和animal對應,frisbee和toy對應,snow則對應於cold weather;圖二dog和cat主語不同明顯是矛盾的;圖3則不容易看出來。這裡是我跑的程式碼。
entailmentcontradictionneutral