
Text Classification (Part 1): Text Classification with PyTorch Using BiLSTM + Attention

1. Architecture Diagram

2. Code

import torch
import torch.nn as nn
import torch.nn.functional as F


class TextBILSTM(nn.Module):

    def __init__(self,
                 config: TRNNConfig,
                 char_size=5000,
                 pinyin_size=5000):
        super(TextBILSTM, self).__init__()
        self.num_classes = config.num_classes
        self.learning_rate = config.learning_rate
        self.keep_dropout = config.keep_dropout
        self.char_embedding_size = config.char_embedding_size
        self.pinyin_embedding_size = config.pinyin_embedding_size
        self.l2_reg_lambda = config.l2_reg_lambda
        self.hidden_dims = config.hidden_dims
        self.char_size = char_size
        self.pinyin_size = pinyin_size
        self.rnn_layers = config.rnn_layers

        self.build_model()

    def build_model(self):
        # Initialize the character embedding table
        self.char_embeddings = nn.Embedding(self.char_size, self.char_embedding_size)
        # Character embeddings are updated during training
        self.char_embeddings.weight.requires_grad = True

        # Initialize the pinyin embedding table
        self.pinyin_embeddings = nn.Embedding(self.pinyin_size, self.pinyin_embedding_size)
        self.pinyin_embeddings.weight.requires_grad = True

        # Attention layer
        self.attention_layer = nn.Sequential(
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True)
        )
        # self.attention_weights = self.attention_weights.view(self.hidden_dims, 1)

        # Stacked (two-layer) bidirectional LSTM
        self.lstm_net = nn.LSTM(self.char_embedding_size, self.hidden_dims,
                                num_layers=self.rnn_layers, dropout=self.keep_dropout,
                                bidirectional=True)

        # Fully connected output layers
        # self.fc_out = nn.Linear(self.hidden_dims, self.num_classes)
        self.fc_out = nn.Sequential(
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.hidden_dims),
            nn.ReLU(inplace=True),
            nn.Dropout(self.keep_dropout),
            nn.Linear(self.hidden_dims, self.num_classes)
        )

    def attention_net_with_w(self, lstm_out, lstm_hidden):
        '''
        :param lstm_out: [batch_size, len_seq, n_hidden * 2]
        :param lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
        :return: [batch_size, n_hidden]
        '''
        lstm_tmp_out = torch.chunk(lstm_out, 2, -1)
        # h: [batch_size, time_step, hidden_dims]
        h = lstm_tmp_out[0] + lstm_tmp_out[1]
        # lstm_hidden: [batch_size, num_layers * num_directions, n_hidden]
        lstm_hidden = torch.sum(lstm_hidden, dim=1)
        # lstm_hidden: [batch_size, 1, n_hidden]
        lstm_hidden = lstm_hidden.unsqueeze(1)
        # atten_w: [batch_size, 1, hidden_dims]
        atten_w = self.attention_layer(lstm_hidden)
        # m: [batch_size, time_step, hidden_dims]
        m = nn.Tanh()(h)
        # atten_context: [batch_size, 1, time_step]
        atten_context = torch.bmm(atten_w, m.transpose(1, 2))
        # softmax_w: [batch_size, 1, time_step]
        softmax_w = F.softmax(atten_context, dim=-1)
        # context: [batch_size, 1, hidden_dims]
        context = torch.bmm(softmax_w, h)
        result = context.squeeze(1)
        return result

    def forward(self, char_id, pinyin_id):
        # char_id = torch.from_numpy(np.array(input[0])).long()
        # pinyin_id = torch.from_numpy(np.array(input[1])).long()
        sen_char_input = self.char_embeddings(char_id)
        sen_pinyin_input = self.pinyin_embeddings(pinyin_id)
        sen_input = torch.cat((sen_char_input, sen_pinyin_input), dim=1)
        # LSTM expects input of shape [len_seq, batch_size, embedding_dim]
        sen_input = sen_input.permute(1, 0, 2)
        output, (final_hidden_state, final_cell_state) = self.lstm_net(sen_input)
        # output: [batch_size, len_seq, n_hidden * 2]
        output = output.permute(1, 0, 2)
        # final_hidden_state: [batch_size, num_layers * num_directions, n_hidden]
        final_hidden_state = final_hidden_state.permute(1, 0, 2)
        # final_hidden_state = torch.mean(final_hidden_state, dim=0, keepdim=True)
        # atten_out = self.attention_net(output, final_hidden_state)
        atten_out = self.attention_net_with_w(output, final_hidden_state)
        return self.fc_out(atten_out)
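The post does not show TRNNConfig or how the model is called, so the following is a minimal sketch of a forward pass, not the author's training code. A types.SimpleNamespace stands in for TRNNConfig; its field names are taken from __init__ above, but all values are illustrative. Note that pinyin_embedding_size is assumed equal to char_embedding_size, because both embeddings are concatenated along the sequence dimension and fed to the same LSTM.

from types import SimpleNamespace

config = SimpleNamespace(
    num_classes=10,
    learning_rate=1e-3,
    keep_dropout=0.1,
    char_embedding_size=128,
    pinyin_embedding_size=128,  # must match char_embedding_size: both feed the same LSTM
    l2_reg_lambda=0.0,
    hidden_dims=256,
    rnn_layers=2,
)

model = TextBILSTM(config)

batch_size, seq_len = 4, 50
char_id = torch.randint(0, 5000, (batch_size, seq_len))    # dummy character ids
pinyin_id = torch.randint(0, 5000, (batch_size, seq_len))  # dummy pinyin ids
logits = model(char_id, pinyin_id)
print(logits.shape)  # torch.Size([4, 10]) -> one score per class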

3. Explanation

1. Split the BiLSTM output (shape: [batch_size, time_step, hidden_dims * num_directions(=2)]) into two tensors, each of shape [batch_size, time_step, hidden_dims];
2. Add the two tensors from step 1 element-wise to obtain h (shape: [batch_size, time_step, hidden_dims]);
3. Sum the BiLSTM's final hidden state (shape: [batch_size, num_layers * num_directions, hidden_dims]) over its second dimension to obtain a new lstm_hidden (shape: [batch_size, hidden_dims]);
4. Expand lstm_hidden from [batch_size, hidden_dims] to [batch_size, 1, hidden_dims];
5. Pass lstm_hidden through self.attention_layer to obtain atten_w, the vector used to compute the attention weights (shape: [batch_size, 1, hidden_dims]);
6. Apply tanh to h to obtain m (shape: [batch_size, time_step, hidden_dims]);
7. Compute atten_context = torch.bmm(atten_w, m.transpose(1, 2)) (shape: [batch_size, 1, time_step]);
8. Normalize atten_context with F.softmax(atten_context, dim=-1) to obtain the context-based weights softmax_w (shape: [batch_size, 1, time_step]);
9. Compute the attention-weighted BiLSTM output context = torch.bmm(softmax_w, h) (shape: [batch_size, 1, hidden_dims]);
10. Squeeze out the second dimension of context to obtain result (shape: [batch_size, hidden_dims]);
11. Return result (a standalone shape-check of these steps is sketched right after this list).
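To make the eleven steps easy to verify, here is a self-contained shape-check that re-implements the attention computation on random tensors. The batch size, sequence length, and hidden size are arbitrary illustrative values, not numbers from the post.

import torch
import torch.nn as nn
import torch.nn.functional as F

batch_size, time_step, hidden_dims = 4, 100, 256
num_layers, num_directions = 2, 2

# Stand-ins for the BiLSTM outputs (random values, shapes as in the steps above)
lstm_out = torch.randn(batch_size, time_step, hidden_dims * num_directions)
lstm_hidden = torch.randn(batch_size, num_layers * num_directions, hidden_dims)
attention_layer = nn.Sequential(nn.Linear(hidden_dims, hidden_dims), nn.ReLU(inplace=True))

# Steps 1-2: split the BiLSTM output and add the two directions
h = sum(torch.chunk(lstm_out, 2, dim=-1))                             # [4, 100, 256]
# Steps 3-4: sum the final hidden state over layers/directions, add a length-1 dim
query = lstm_hidden.sum(dim=1).unsqueeze(1)                           # [4, 1, 256]
# Step 5: project the query through the attention layer
atten_w = attention_layer(query)                                      # [4, 1, 256]
# Step 6: tanh activation of h
m = torch.tanh(h)                                                     # [4, 100, 256]
# Steps 7-8: scores over time steps, normalized with softmax
softmax_w = F.softmax(torch.bmm(atten_w, m.transpose(1, 2)), dim=-1)  # [4, 1, 100]
# Steps 9-11: weighted sum of h, squeezed to [batch_size, hidden_dims]
result = torch.bmm(softmax_w, h).squeeze(1)                           # [4, 256]
print(result.shape)  # torch.Size([4, 256])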

4. Practical Notes

Model performance
1-layer BiLSTM: 99.8% accuracy on the training set, 96.5% on the test set;
2-layer BiLSTM: 99.9% accuracy on the training set, 97.3% on the test set;
Hyperparameter tuning
Keep the dropout value at 0.1 or below (a rule of thumb from practice: in my experiments, a dropout of 0.1 gave about 0.5% higher test-set accuracy than 0.3).
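As a small illustration of where that knob lives in the code above (reusing the hypothetical SimpleNamespace config from the earlier sketch):

config.keep_dropout = 0.1   # the value reported to work best in these experiments
model = TextBILSTM(config)  # rebuild: keep_dropout feeds nn.LSTM(dropout=...) and the nn.Dropout layers in fc_out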
https://blog.csdn.net/dendi_hust/article/details/94435919