1. 程式人生 > 其它 >On Tree-Based Neural Sentence Modeling閱讀筆記&原碼解讀

On Tree-Based Neural Sentence Modeling閱讀筆記&原碼解讀

技術標籤:文獻閱讀

標題

On Tree-Based Neural Sentence Modeling
專案地址:
https://github.com/ExplorerFreda/TreeEnc

背景

  • syntactic parsing trees(Socher et al., 2013; Tai et al., 2015; Zhu et al., 2015)
  • trivial trees
    • binary balanced tree
    • left-branching tree
    • right-branching tree
  • Latent Trees(Choi et al., 2018;Yogatama et al., 2017; Maillard et al., 2017;Williams et al., 2018)
    • reforcement learning (Williams, 1992)
    • Gumbel Softmax (Jang et al., 2017; Maddison et al.,2017)
  • Linear Structures
    • LSTM(Hochreiter and Schmidhuber, 1997)
    • GRU(Cho et al., 2014)

傳統的基於樹的句子編碼器,都是給定一個明確的先驗:句子的syntactic parsing structure。實驗表明,沒有明確的語法資訊的trivial tree也能獲得很好、甚至更好的效果。
本文探究在基於樹的句子建模中,究竟什麼因素是最重要的。
本篇部落格只考慮 text classification這一downstream task。

資料(以文字分類任務的dbpedia為例)

one_example = {“label”:…, “sentence”:…, “constituency_tree_encoding”:…}
首先,對於沒有parsing trees的資料集,使用ZPar來得到句子的binary parsing tree。
binary parsing tree如下圖所示:
在這裡插入圖片描述
即:

(I ((love (my (pet cat ))).))

接著,我們對這個二叉樹進行轉化,來得到句子的"constituency_tree_encoding"
轉化的程式碼在https://github.com/ExplorerFreda/TreeEnc/blob/master/data/README.md

中給出。
大致思路是維護一個棧,一旦遇到右括號,則出棧中前兩個元素(stack[-2],stack[-1]),分別作為left和right,而stack[-3](是對應的括號),則作為子樹的parent。

再然後,從程式碼中可以看到,對於輸入的constituency_tree_encoding,可以得到其mask_ids。使用的程式碼如下。這部分的作用還不太懂。大致流程就是,每一步(一共sentence_length-1步),找到left、righ、parent,並把left和right替換成parent,並在mask_ids中加入left在當前句子(curr_sent)中的下標。

    def tree_encoding_to_mask_ids(tree_encoding):
        items = [int(x) for x in tree_encoding.strip().split(',')]
        sentence_length = len(items) // 3 + 1 # (sentence_length - 1) * 3 = len(items)
        curr_sent = [x for x in range(sentence_length)]
        mask_ids = list()
        assert 3 * (sentence_length - 1) == len(items)
        for i in range(sentence_length - 1):
            left_node = items[i * 3]
            right_node = items[i * 3 + 1]
            father_node = items[i * 3 + 2]
            left_index = curr_sent.index(left_node)
            assert curr_sent[left_index + 1] == right_node
            curr_sent = curr_sent[:left_index] + [father_node] + curr_sent[left_index + 2:]
            mask_ids.append(left_index)
        return mask_ids

另外,把sentence->word_ids;統計length;得到上述mask_ids;把label也->idx,就得到了dataset
最後,為了得到dataloader,僅需要看看collate_fn是在幹什麼:

  • sentences = torch.LongTensor(self.pad_sentence(words_batch))
  • masks = self.make_one_hot_gold_mask(self.pad_mask(mask_ids_batch))
    • padded = [d + [0] * (max_length - len(d)) for d in data]: 取batch中最長的sentence_length -1 ,然後補0
    • self.make_one_hot_gold_mask()是真的不知道在幹嘛…
def collate(self, batch):
    words_batch, raw_sentences_batch, length_batch, mask_ids_batch, label_batch = list(zip(*batch))
    sentences = torch.LongTensor(self.pad_sentence(words_batch))
    lengths = torch.LongTensor(length_batch)
    try:
        masks = self.make_one_hot_gold_mask(self.pad_mask(mask_ids_batch))
    except TypeError:
        masks = None
    labels = torch.LongTensor(label_batch)
    return {'sentences': sentences, 'lengths': lengths, 'masks': masks, 'labels': labels,
           'raw_sentences': raw_sentences_batch}

模型

Parsing tree

略,見下。

Binary balanced trees

略,見下。

Left-branching trees

略,見下。

Right-branching trees

不略。啊,上述四種tree居然使用了同一個類:class RecursiveTreeLSTMEncoder(TreeLSTMEncoder)。這部分理解應該比較困難的吧。。畢竟我對於tree_masks都還沒有理解(明天結合例子在紙上跑一跑吧!打工人)

Gumbel trees

這個可能不打算用?再說吧。

LSTM

對應了程式碼中class LinearLSTMEncoder(nn.Module)這一個類。主要有以下細節

  • att_weights, _ = pad_packed_sequence(att_weights, batch_first=True, padding_value=-1e8): 利用pack和pad來把padding位置的att_weights權重設定為-1e8是第一次見
  • encodings = torch.cat((torch.cat(forward_encodings, dim=0),hiddens[:, 0, self.hidden_size:]), dim=1):雙向lstm中,取出前向和後向的拼接,這種方式也是第一次見,可能這種就是正確的方式。

就是(bi)lstm + pooling(max\mean\attn\或者取last)。


bi-leaf-RNN

word embedding --bileafrnn --> leaf node representations

pooling