1. 程式人生 > 其它 >Seq2seq模型的一個變種網路:Pointer Network的簡單介紹

Seq2seq模型的一個變種網路:Pointer Network的簡單介紹

Pointer Network(為方便起見以下稱為指標網路)是seq2seq模型的一個變種。他們不是把一個序列轉換成另一個序列, 而是產生一系列指向輸入序列元素的指標。最基礎的用法是對可變長度序列或集合的元素進行排序。

seq2seq的基礎是一個LSTM編碼器加上一個LSTM解碼器。在機器翻譯的語境中, 最常聽到的是: 用一種語言造句, 編碼器把它變成一個固定大小的陳述。解碼器將他轉換成一個句子, 可能和之前的句子長度不同。例如, “como estas?”-兩個單詞-將被翻譯成 “how are you?”-三個單詞。

當“注意力”增強時模型效果會更好。這意味著解碼器在輸入的前後都可以訪問。就是說, 它可以從每個步驟訪問編碼器狀態, 而不僅僅是最後一個。思考一下它怎樣幫助西班牙語讓形容詞在名詞之前: “neural network”變成 “red neuronal”

在專業術語中,“注意力”(至少是這種特定的 基於內容的注意力) 歸結為加權平均值均數。簡而言之,編碼器狀態的加權平均值轉換為解碼器狀態。注意力只是權重的分配。

想知道更多可以訪問:https://medium.com/datalogue/attention-in-keras-1892773a4f22

在指標網路中, 注意力更簡單:它不考慮輸入元素,而是在概率上指向它們。實際上,你得到了輸入的排列。有關更多細節和公式, 請參閱論文:

https://arxiv.org/abs/1506.03134

注意, 不需要使用所有的指標。例如, 給定一段文字, 網路可以通過指向兩個元素來標記摘錄: 它的起始位置和結束位置。

實驗

我們從順序數字開始?換句話說,一個深入的argsort:

In [3]: np.argsort([10,30,20 ])
Out[3]: array([0,2,1], dtype=int64)

In [4]: np.argsort([40,10,30,20 ])
Out[4]: array([1,3,2,0], dtype=int64)

令人驚訝的是,作者在論文中沒有繼續進行完成任務。相反的,他們使用兩個奇特的問題:旅行推銷員和凸包(參考README), 雖然結果是好的。但為什麼不按照數字順序呢?

原來,數字排序很難做到。他們在後續檔案中提到了這個問題(Order Matters: Sequence to sequence for sets)。重點是順序不能錯。也就是說,我們討論的是輸入元素的順序。作者發現,它對結果影響很大, 這不是我們想要的。因為本質上我們處理的是集合作為輸入, 而不是序列。集合沒有固定的順序,所以元素是如何排列在理論上不應該影響結果。

因此, 本文介紹了一種改進的架構, 它們通過連線到另一個LSTM的前饋網路來替換LSTM編碼器。這就是說,LSTM重複執行,以產生一個置換不變的嵌入給輸入。解碼器同樣是一個指標網路。

讓我們回到數字排列。較長的集合更難去排列。對於5個數字,他們報告的準確度範圍是81%-94%, 具體取決於模型 (這裡提到的準確度是指正確排序序列的百分比)。當處理15數字時, 這個範圍變成了0%-10%。

在我們的研究中,對於五個數字,我們幾乎達到了100%的準確度。請注意, 這是Keras所報告的 “分類精度”, 意思是在正確位置上元素的百分比。例如, 這個例子是50%準確度,即前兩個元素不動, 但最後兩個被調換:

4 3 2 1 ->3 2 0 1

對於有八元素的序列, 分類精度下降到大約33%。我們還嘗試了一個更具挑戰性的任務, 按它們的和對一個集合進行排序:

[1 2] [3 4] [2 3]->0 2 1

網路處理它就像處理簡單的(un)標量數字。

我們注意到的一個意想不到的事情是, 網路傾向於重複指標, 尤其是在訓練的早期。這是令人失望的:顯然它不記得它不久之前的預測。

y_test: [2 0 1 4 3]
p:      [2 2 2 2 2]

在訓練的早期, 人們聚集在一起, 構想指標網路的輸出。

y_test: [2 0 1 4 3]
p:      [2 0 2 4 3]

同時, 訓練有時會被某種準確度所困。而一個對少量數字進行訓練的網路並不能概括更大的, 比如:

981,66,673
856,10,438
884,808,241

為了幫助網路使用數字, 我們新增一個 ID (1,2, 3…) 到序列的每個元素。這個假設是因為注意力是基於內容的, 也許它可以使用內容中明確編碼的位置。此ID是一個數字 (train_with_positions) 或獨熱向量 (train_with_positions_categorical)。這看起來有點效果,但沒有解決根本問題。

實驗程式碼在GitHub可以使用。與original repo相比, 我們添加了一個數據生成指令碼, 並更改了訓練指令碼以從生成的檔案中載入資料。我們還將優化演算法改成RMSPro, 因為它在處理學習率的過程中似乎收斂得很好。

資料結構

3D陣列中的資料。第一個維度 (行) 是像往常一樣的例子。第二個維度“列”通常是特徵(屬性), 但帶序列的特徵進入第三個維度。第二個維度由給定序列的元素組成。下面是三個序列示例, 每個都有三個元素 (步驟), 每個元素都有兩個特徵:

array([[[8, 2],
        [3, 3],
        [10, 3]],

       [[1, 4],
        [19,12],
        [4,10]],

       [[19, 0],
        [15,12],
        [8, 6]],

目標是按特徵的和對元素進行排序, 因此相應的目標將是:

array([[1,0,2],
       [0,2,1],
       [2,0,1],

並且,它們將被明確編碼:

array([[[0., 1., 0.],
        [1., 0., 0.],
        [0., 0., 1.]],

       [[1., 0., 0.],
        [0., 0., 1.],
        [0., 1., 0.]],

       [[0., 0., 1.],
        [1., 0., 0.],
        [0., 1., 0.]],

這裡有一個問題,我們一直在討論迴圈網路如何處理可變長度的序列,但實際上資料是3D陣列,如上所示。換句話說,序列長度是固定的。

處理這一問題的方法是在最大可能的序列長度上固定維度, 並用零填充未使用的位置。

但它有可能搞亂代價函式,因此我們更好地掩蓋那些零, 確保他們在計算損失時被省略。Keras官方的做法似乎是embdedding layer。相關引數為mask_zero:

mask_zero: 無論輸入值0是否是一個特殊的 “padding” 值, 都應該被遮蔽掉。當使用可變長度輸入的迴圈層時這很有用。如果它為“True”,那麼模型中的所有後續層都需要支援掩蔽, 否則將引發異常。如果 mask_zero設定為True, 那麼作為一個序列,詞彙表中不能使用索引0(input_dim應等於詞彙量“+1”)。

關於實現

我們使用了一個Keras執行的指標網路。GitHub上還有一些其他的, 大部分用Tensorflow。

附錄A:指標網路的實現

  • https://github.com/keon/pointer-networks 幻燈片
  • https://github.com/devsisters/pointer-network-tensorflow
  • https://github.com/vshallc/PtrNets
  • https://github.com/ikostrikov/TensorFlow-Pointer-Networks
  • https://github.com/Chanlaw/pointer-networks
  • https://github.com/devnag/tensorflow-pointer-networks
  • https://github.com/udibr/pointer-generator
  • https://github.com/JerrikEph/SentenceOrdering_PTR
  • https://github.com/pradyu1993/seq2set-keras

附錄B:seq2seq的一些注意力的實現

  • https://github.com/philipperemy/keras-attention-mechanism
  • https://github.com/tensorflow/models/tree/master/textsum
  • https://github.com/tensorflow/tensor2tensor
  • Translation with a Sequence to Sequence Network and Attention (PyTorch tutorial)
  • https://github.com/MaximumEntropy/Seq2Seq-PyTorch
  • https://github.com/rowanz/pytorch-seq2seq
  • https://github.com/chainer/chainer/tree/seq2seq-europal/examples/seq2seq