On Tree-Based Neural Sentence Modeling閱讀筆記&原碼解讀
技術標籤:文獻閱讀
標題
On Tree-Based Neural Sentence Modeling
專案地址:
https://github.com/ExplorerFreda/TreeEnc
背景
- syntactic parsing trees(Socher et al., 2013; Tai et al., 2015; Zhu et al., 2015)
- trivial trees
- binary balanced tree
- left-branching tree
- right-branching tree
- Latent Trees(Choi et al., 2018;Yogatama et al., 2017; Maillard et al., 2017;Williams et al., 2018)
- reforcement learning (Williams, 1992)
- Gumbel Softmax (Jang et al., 2017; Maddison et al.,2017)
- Linear Structures
- LSTM(Hochreiter and Schmidhuber, 1997)
- GRU(Cho et al., 2014)
傳統的基於樹的句子編碼器,都是給定一個明確的先驗:句子的syntactic parsing structure。實驗表明,沒有明確的語法資訊的trivial tree也能獲得很好、甚至更好的效果。
本文探究在基於樹的句子建模中,究竟什麼因素是最重要的。
本篇部落格只考慮 text classification這一downstream task。
資料(以文字分類任務的dbpedia為例)
one_example = {“label”:…, “sentence”:…, “constituency_tree_encoding”:…}
首先,對於沒有parsing trees的資料集,使用ZPar來得到句子的binary parsing tree。
binary parsing tree如下圖所示:
即:
(I ((love (my (pet cat ))).))
接著,我們對這個二叉樹進行轉化,來得到句子的"constituency_tree_encoding"
轉化的程式碼在https://github.com/ExplorerFreda/TreeEnc/blob/master/data/README.md
大致思路是維護一個棧,一旦遇到右括號,則出棧中前兩個元素(stack[-2],stack[-1]),分別作為left和right,而stack[-3](是對應的括號),則作為子樹的parent。
再然後,從程式碼中可以看到,對於輸入的constituency_tree_encoding,可以得到其mask_ids。使用的程式碼如下。這部分的作用還不太懂。大致流程就是,每一步(一共sentence_length-1步),找到left、righ、parent,並把left和right替換成parent,並在mask_ids中加入left在當前句子(curr_sent)中的下標。
def tree_encoding_to_mask_ids(tree_encoding):
items = [int(x) for x in tree_encoding.strip().split(',')]
sentence_length = len(items) // 3 + 1 # (sentence_length - 1) * 3 = len(items)
curr_sent = [x for x in range(sentence_length)]
mask_ids = list()
assert 3 * (sentence_length - 1) == len(items)
for i in range(sentence_length - 1):
left_node = items[i * 3]
right_node = items[i * 3 + 1]
father_node = items[i * 3 + 2]
left_index = curr_sent.index(left_node)
assert curr_sent[left_index + 1] == right_node
curr_sent = curr_sent[:left_index] + [father_node] + curr_sent[left_index + 2:]
mask_ids.append(left_index)
return mask_ids
另外,把sentence->word_ids;統計length;得到上述mask_ids;把label也->idx,就得到了dataset
最後,為了得到dataloader,僅需要看看collate_fn是在幹什麼:
- sentences = torch.LongTensor(self.pad_sentence(words_batch))
- masks = self.make_one_hot_gold_mask(self.pad_mask(mask_ids_batch))
- padded = [d + [0] * (max_length - len(d)) for d in data]: 取batch中最長的sentence_length -1 ,然後補0
- self.make_one_hot_gold_mask()是真的不知道在幹嘛…
def collate(self, batch):
words_batch, raw_sentences_batch, length_batch, mask_ids_batch, label_batch = list(zip(*batch))
sentences = torch.LongTensor(self.pad_sentence(words_batch))
lengths = torch.LongTensor(length_batch)
try:
masks = self.make_one_hot_gold_mask(self.pad_mask(mask_ids_batch))
except TypeError:
masks = None
labels = torch.LongTensor(label_batch)
return {'sentences': sentences, 'lengths': lengths, 'masks': masks, 'labels': labels,
'raw_sentences': raw_sentences_batch}
模型
Parsing tree
略,見下。
Binary balanced trees
略,見下。
Left-branching trees
略,見下。
Right-branching trees
不略。啊,上述四種tree居然使用了同一個類:class RecursiveTreeLSTMEncoder(TreeLSTMEncoder)。這部分理解應該比較困難的吧。。畢竟我對於tree_masks都還沒有理解(明天結合例子在紙上跑一跑吧!打工人)
Gumbel trees
這個可能不打算用?再說吧。
LSTM
對應了程式碼中class LinearLSTMEncoder(nn.Module)這一個類。主要有以下細節
- att_weights, _ = pad_packed_sequence(att_weights, batch_first=True, padding_value=-1e8): 利用pack和pad來把padding位置的att_weights權重設定為-1e8是第一次見
- encodings = torch.cat((torch.cat(forward_encodings, dim=0),hiddens[:, 0, self.hidden_size:]), dim=1):雙向lstm中,取出前向和後向的拼接,這種方式也是第一次見,可能這種就是正確的方式。
就是(bi)lstm + pooling(max\mean\attn\或者取last)。
bi-leaf-RNN
word embedding --bileafrnn --> leaf node representations