
TensorFlow RNN Deep Learning: BiLSTM+CRF for Sequence Labeling


Bidirectional LSTM + CRF for the sequence labeling problem

Source code

Since around the end of last year I have been working on an NLP task, a sequence labeling problem. Sequence labeling is a classic NLP problem, so I started with HMM — no, scratch that, with CRF — as the baseline; by the way, I used CRF++.

I won't rehash CRF theory here, it is covered everywhere. In passing: CRF beats HMM both in theory and in practice. Still, CRF was not encouraging on my task: precision was around 0.6, recall was absurdly low, so F1 looked grim. My mentor said the features were insufficient; a senior labmate said the task itself is just hard, and a low F1 is to be expected.
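For reference, the P/R/F1 figures above are entity-level metrics. A minimal sketch of the arithmetic, with counter names chosen to mirror the evaluate() method in the code below (the function itself is mine, not from the repo):

def prf(hit_num, pred_num, true_num):
    # hit_num: entities that are both predicted and in the gold standard
    # pred_num: total predicted entities; true_num: total gold entities
    precision = float(hit_num) / pred_num if pred_num > 0 else 0.0
    recall = float(hit_num) / true_num if true_num > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return precision, recall, f1

# e.g. 60 hits out of 100 predictions against 300 gold entities:
# prf(60, 100, 300) -> (0.6, 0.2, 0.3): P around 0.6, R low, F1 grim, as above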

With the CRF baseline done, I set out to run BiLSTM+CRF on the sequence labeling task, but a pile of other projects left no spare energy to carry out the plan properly. Bit by bit, following the steps laid out by the experts and referring to existing code, I finally got the BiLSTM+CRF implementation done. It turned out the results were still not ideal... maybe the task really is that perverse, or the model needs further strengthening.

Here is a comparison of CRF and the LSTM cell. Start with RNNs: an RNN is better suited to sequence problems than a CNN, because part of the hidden layer's input at the current time step is the hidden output from the previous step. Through this recurrent feedback connection it can see earlier information, capturing the preceding context of a sequence and bringing it into the current computation, and it has nonlinear fitting capacity on top — both things a CRF cannot match. The LSTM cell, in turn, nicely mitigates the RNN's vanishing gradient problem; it tells the gatekeeper: "friend, some information just isn't that important — forget it rather than tying up resources." A bidirectional LSTM is stronger still: it sees not only the past but also folds the future of the sequence into the computation, so the context is fully exploited.

A CRF, by contrast, cannot draw on long-range context the way an LSTM can; it mostly considers a linear weighted combination of local features across the whole sentence (scanning it with feature templates). Its distinctive strength is that it computes a joint probability and optimizes the entire sequence, rather than stitching together the best label at each individual time step. Putting BiLSTM and CRF together therefore makes a rather good combination, and it is currently the popular approach in the literature.
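Concretely, the linear-chain CRF used below scores a tag path as the sum of per-position emission scores plus tag-to-tag transition scores; training minimizes the negative log-likelihood, i.e. the log-sum-exp over all paths (computed with the forward algorithm) minus the gold-path score. A minimal numpy sketch of that computation (illustrative only; the names are mine, not from the code below):

import numpy as np

def crf_neg_log_likelihood(emissions, transitions, tags):
    # emissions: [num_steps, num_classes] per-position tag scores (e.g. BiLSTM outputs)
    # transitions: [num_classes, num_classes] score of moving from tag i to tag j
    # tags: [num_steps] gold tag ids
    num_steps, num_classes = emissions.shape
    # gold-path score: emissions along the path + transitions between consecutive tags
    gold = emissions[np.arange(num_steps), tags].sum()
    gold += transitions[tags[:-1], tags[1:]].sum()
    # forward algorithm: log-sum-exp over all possible paths
    alpha = emissions[0]
    for t in range(1, num_steps):
        # scores[i, j] = alpha[i] + transitions[i, j] + emissions[t, j]
        scores = alpha[:, None] + transitions + emissions[t][None, :]
        m = scores.max(axis=0)
        alpha = m + np.log(np.exp(scores - m).sum(axis=0))
    m = alpha.max()
    log_z = m + np.log(np.exp(alpha - m).sum())
    return log_z - gold  # loss; minimized when the gold path dominates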

On top of the current working version, a few improvements to try:

1. +CNN: use convolutions to pick up letter-level detail inside English words (see the sketch after this list).

2. +char representation: similar in spirit to the above, extracting finer-grained detail.

3. More joint models to explore.
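As a sketch of point 1 (my own toy code, not from this repo): a character-level CNN concatenates a sliding window of character embeddings, applies a bank of filters with a ReLU, and max-pools over positions to get a fixed-size feature vector per word:

import numpy as np

def char_cnn_word_feature(char_embs, filters, width=3):
    # char_embs: [word_len, char_dim] embeddings of the word's characters
    # filters: [num_filters, width * char_dim] convolution filters
    # assumes word_len >= width (pad short words with a dummy character)
    word_len, char_dim = char_embs.shape
    windows = np.array([char_embs[i:i + width].reshape(-1)
                        for i in range(word_len - width + 1)])  # [num_windows, width * char_dim]
    conv = np.maximum(windows.dot(filters.T), 0)  # conv + ReLU: [num_windows, num_filters]
    return conv.max(axis=0)                       # max-pool over positions: [num_filters]

# e.g. 64 filters of width 3 over 25-dim char embeddings of a 7-letter word:
# rng = np.random.RandomState(0)
# char_cnn_word_feature(rng.randn(7, 25), rng.randn(64, 75)).shape  # (64,)

The pooled vector can then be concatenated to the word (or char) embedding before the BiLSTM.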

Fine, enough chatter. Code time:

requirements:

ubuntu 14
Python 2.7
TensorFlow 0.8
numpy
pandas 0.15

BILSTM_CRF.py

import math
import helper
import numpy as np
import tensorflow as tf
from tensorflow.models.rnn import rnn, rnn_cell
class BILSTM_CRF(object):
    def __init__(self, num_chars, num_classes, num_steps=200, num_epochs=100, embedding_matrix=None, is_training=True, is_crf=True, weight=False):
        # parameters
        self.max_f1 = 0
        self.learning_rate = 0.002
        self.dropout_rate = 0.5
        self.batch_size = 128
        self.num_layers = 1
        self.emb_dim = 100
        self.hidden_dim = 100
        self.num_epochs = num_epochs
        self.num_steps = num_steps
        self.num_chars = num_chars
        self.num_classes = num_classes

        # placeholders for x, y and weight
        self.inputs = tf.placeholder(tf.int32, [None, self.num_steps])
        self.targets = tf.placeholder(tf.int32, [None, self.num_steps])
        self.targets_weight = tf.placeholder(tf.float32, [None, self.num_steps])
        self.targets_transition = tf.placeholder(tf.int32, [None])

        # char embedding
        if embedding_matrix is not None:
            self.embedding = tf.Variable(embedding_matrix, trainable=False, name="emb", dtype=tf.float32)
        else:
            self.embedding = tf.get_variable("emb", [self.num_chars, self.emb_dim])
        self.inputs_emb = tf.nn.embedding_lookup(self.embedding, self.inputs)
        # reshape to a num_steps-long list of [batch_size, emb_dim] tensors for rnn()
        self.inputs_emb = tf.transpose(self.inputs_emb, [1, 0, 2])
        self.inputs_emb = tf.reshape(self.inputs_emb, [-1, self.emb_dim])
        self.inputs_emb = tf.split(0, self.num_steps, self.inputs_emb)
        # lstm cells
        lstm_cell_fw = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_dim)
        lstm_cell_bw = tf.nn.rnn_cell.BasicLSTMCell(self.hidden_dim)

        # dropout
        if is_training:
            lstm_cell_fw = tf.nn.rnn_cell.DropoutWrapper(lstm_cell_fw, output_keep_prob=(1 - self.dropout_rate))
            lstm_cell_bw = tf.nn.rnn_cell.DropoutWrapper(lstm_cell_bw, output_keep_prob=(1 - self.dropout_rate))

        lstm_cell_fw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_fw] * self.num_layers)
        lstm_cell_bw = tf.nn.rnn_cell.MultiRNNCell([lstm_cell_bw] * self.num_layers)

        # get the length of each sample (char id 0 is the padding symbol)
        self.length = tf.reduce_sum(tf.sign(self.inputs), reduction_indices=1)
        self.length = tf.cast(self.length, tf.int32)

        # forward and backward passes
        self.outputs, _, _ = rnn.bidirectional_rnn(
            lstm_cell_fw,
            lstm_cell_bw,
            self.inputs_emb,
            dtype=tf.float32,
            sequence_length=self.length
        )

        # softmax projection to per-tag scores
        self.outputs = tf.reshape(tf.concat(1, self.outputs), [-1, self.hidden_dim * 2])
        self.softmax_w = tf.get_variable("softmax_w", [self.hidden_dim * 2, self.num_classes])
        self.softmax_b = tf.get_variable("softmax_b", [self.num_classes])
        self.logits = tf.matmul(self.outputs, self.softmax_w) + self.softmax_b
        if not is_crf:
            pass
        else:
            self.tags_scores = tf.reshape(self.logits, [self.batch_size, self.num_steps, self.num_classes])
            # transition matrix over num_classes + 1 states (the extra state is the padding class)
            self.transitions = tf.get_variable("transitions", [self.num_classes + 1, self.num_classes + 1])

            dummy_val = -1000
            class_pad = tf.Variable(dummy_val * np.ones((self.batch_size, self.num_steps, 1)), dtype=tf.float32)
            self.observations = tf.concat(2, [self.tags_scores, class_pad])

            # begin/end vectors pin the allowed state at the artificial boundary positions
            begin_vec = tf.Variable(np.array([[dummy_val] * self.num_classes + [0] for _ in range(self.batch_size)]), trainable=False, dtype=tf.float32)
            end_vec = tf.Variable(np.array([[0] + [dummy_val] * self.num_classes for _ in range(self.batch_size)]), trainable=False, dtype=tf.float32)
            begin_vec = tf.reshape(begin_vec, [self.batch_size, 1, self.num_classes + 1])
            end_vec = tf.reshape(end_vec, [self.batch_size, 1, self.num_classes + 1])

            self.observations = tf.concat(1, [begin_vec, self.observations, end_vec])

            self.mask = tf.cast(tf.reshape(tf.sign(self.targets), [self.batch_size * self.num_steps]), tf.float32)

            # point score: emission score of the gold tag at each (unpadded) position
            self.point_score = tf.gather(tf.reshape(self.tags_scores, [-1]), tf.range(0, self.batch_size * self.num_steps) * self.num_classes + tf.reshape(self.targets, [self.batch_size * self.num_steps]))
            self.point_score *= self.mask

            # transition score of the gold tag bigrams
            self.trans_score = tf.gather(tf.reshape(self.transitions, [-1]), self.targets_transition)

            # score of the gold path
            self.target_path_score = tf.reduce_sum(self.point_score) + tf.reduce_sum(self.trans_score)

            # log-partition over all paths (plus Viterbi tables for decoding)
            self.total_path_score, self.max_scores, self.max_scores_pre = self.forward(self.observations, self.transitions, self.length)

            # negative log-likelihood
            self.loss = - (self.target_path_score - self.total_path_score)

        # summary
        self.train_summary = tf.scalar_summary("loss", self.loss)
        self.val_summary = tf.scalar_summary("loss", self.loss)

        self.optimizer = tf.train.AdamOptimizer(learning_rate=self.learning_rate).minimize(self.loss)
    def logsumexp(self, x, axis=None):
        # numerically stable log(sum(exp(x))) along the given axis
        x_max = tf.reduce_max(x, reduction_indices=axis, keep_dims=True)
        x_max_ = tf.reduce_max(x, reduction_indices=axis)
        return x_max_ + tf.log(tf.reduce_sum(tf.exp(x - x_max), reduction_indices=axis))
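    # forward() below runs the CRF dynamic program over [begin] + sentence + [end]
    # (hence num_steps + 2 positions), computing both the log-partition
    # (logsumexp over all paths) and the Viterbi max/argmax tables that
    # viterbi() later uses for decoding.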
    def forward(self, observations, transitions, length, is_viterbi=True, return_best_seq=True):
        # NOTE: the 6s below hard-code num_classes + 1 (5 tag classes plus the padding class)
        length = tf.reshape(length, [self.batch_size])
        transitions = tf.reshape(tf.concat(0, [transitions] * self.batch_size), [self.batch_size, 6, 6])
        observations = tf.reshape(observations, [self.batch_size, self.num_steps + 2, 6, 1])
        observations = tf.transpose(observations, [1, 0, 2, 3])
        previous = observations[0, :, :, :]
        max_scores = []
        max_scores_pre = []
        alphas = [previous]
        for t in range(1, self.num_steps + 2):
            previous = tf.reshape(previous, [self.batch_size, 6, 1])
            current = tf.reshape(observations[t, :, :, :], [self.batch_size, 1, 6])
            # broadcast: alpha_t[b, i, j] = previous[b, i] + current[b, j] + transitions[b, i, j]
            alpha_t = previous + current + transitions
            if is_viterbi:
                max_scores.append(tf.reduce_max(alpha_t, reduction_indices=1))
                max_scores_pre.append(tf.argmax(alpha_t, dimension=1))
            alpha_t = tf.reshape(self.logsumexp(alpha_t, axis=1), [self.batch_size, 6, 1])
            alphas.append(alpha_t)
            previous = alpha_t
        alphas = tf.reshape(tf.concat(0, alphas), [self.num_steps + 2, self.batch_size, 6, 1])
        alphas = tf.transpose(alphas, [1, 0, 2, 3])
        alphas = tf.reshape(alphas, [self.batch_size * (self.num_steps + 2), 6, 1])
        # pick each sequence's alpha at its true end position
        last_alphas = tf.gather(alphas, tf.range(0, self.batch_size) * (self.num_steps + 2) + length)
        last_alphas = tf.reshape(last_alphas, [self.batch_size, 6, 1])
        max_scores = tf.reshape(tf.concat(0, max_scores), (self.num_steps + 1, self.batch_size, 6))
        max_scores_pre = tf.reshape(tf.concat(0, max_scores_pre), (self.num_steps + 1, self.batch_size, 6))
        max_scores = tf.transpose(max_scores, [1, 0, 2])
        max_scores_pre = tf.transpose(max_scores_pre, [1, 0, 2])
        return tf.reduce_sum(self.logsumexp(last_alphas, axis=1)), max_scores, max_scores_pre
    def train(self, sess, save_file, X_train, y_train, X_val, y_val):
        saver = tf.train.Saver()
        char2id, id2char = helper.loadMap("char2id")
        label2id, id2label = helper.loadMap("label2id")
        merged = tf.merge_all_summaries()
        summary_writer_train = tf.train.SummaryWriter('loss_log/train_loss', sess.graph)
        summary_writer_val = tf.train.SummaryWriter('loss_log/val_loss', sess.graph)

        num_iterations = int(math.ceil(1.0 * len(X_train) / self.batch_size))

        cnt = 0
        for epoch in range(self.num_epochs):
            # shuffle train in each epoch
            sh_index = np.arange(len(X_train))
            np.random.shuffle(sh_index)
            X_train = X_train[sh_index]
            y_train = y_train[sh_index]
            print "current epoch: %d" % (epoch)
            for iteration in range(num_iterations):
                # train
                X_train_batch, y_train_batch = helper.nextBatch(X_train, y_train, start_index=iteration * self.batch_size, batch_size=self.batch_size)
                # upweight entity boundary labels ('B' and 'E')
                y_train_weight_batch = 1 + np.array((y_train_batch == label2id['B']) | (y_train_batch == label2id['E']), float)
                transition_batch = helper.getTransition(y_train_batch)

                _, loss_train, max_scores, max_scores_pre, length, train_summary = \
                    sess.run([
                        self.optimizer,
                        self.loss,
                        self.max_scores,
                        self.max_scores_pre,
                        self.length,
                        self.train_summary
                    ],
                    feed_dict={
                        self.targets_transition: transition_batch,
                        self.inputs: X_train_batch,
                        self.targets: y_train_batch,
                        self.targets_weight: y_train_weight_batch
                    })

                predicts_train = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
                if iteration % 10 == 0:
                    cnt += 1
                    precision_train, recall_train, f1_train = self.evaluate(X_train_batch, y_train_batch, predicts_train, id2char, id2label)
                    summary_writer_train.add_summary(train_summary, cnt)
                    print "iteration: %5d, train loss: %5d, train precision: %.5f, train recall: %.5f, train f1: %.5f" % (iteration, loss_train, precision_train, recall_train, f1_train)

                # validation
                if iteration % 100 == 0:
                    X_val_batch, y_val_batch = helper.nextRandomBatch(X_val, y_val, batch_size=self.batch_size)
                    y_val_weight_batch = 1 + np.array((y_val_batch == label2id['B']) | (y_val_batch == label2id['E']), float)
                    transition_batch = helper.getTransition(y_val_batch)
                    loss_val, max_scores, max_scores_pre, length, val_summary = \
                        sess.run([
                            self.loss,
                            self.max_scores,
                            self.max_scores_pre,
                            self.length,
                            self.val_summary
                        ],
                        feed_dict={
                            self.targets_transition: transition_batch,
                            self.inputs: X_val_batch,
                            self.targets: y_val_batch,
                            self.targets_weight: y_val_weight_batch
                        })

                    predicts_val = self.viterbi(max_scores, max_scores_pre, length, predict_size=self.batch_size)
                    precision_val, recall_val, f1_val = self.evaluate(X_val_batch, y_val_batch, predicts_val, id2char, id2label)
                    summary_writer_val.add_summary(val_summary, cnt)
                    print "iteration: %5d, valid loss: %5d, valid precision: %.5f, valid recall: %.5f, valid f1: %.5f" % (iteration, loss_val, precision_val, recall_val, f1_val)
                    if f1_val > self.max_f1:
                        self.max_f1 = f1_val
                        save_path = saver.save(sess, save_file)
                        print "saved the best model with f1: %.5f" % (self.max_f1)
    def test(self, sess, X_test, X_test_str, output_path):
        char2id, id2char = helper.loadMap("char2id")
        label2id, id2label = helper.loadMap("label2id")
        num_iterations = int(math.ceil(1.0 * len(X_test) / self.batch_size))
        print "number of iteration: " + str(num_iterations)
        with open(output_path, "wb") as outfile:
            for i in range(num_iterations):
                print "iteration: " + str(i + 1)
                results = []
                X_test_batch = X_test[i * self.batch_size : (i + 1) * self.batch_size]
                X_test_str_batch = X_test_str[i * self.batch_size : (i + 1) * self.batch_size]
                if i == num_iterations - 1 and len(X_test_batch) < self.batch_size:
                    # pad the last batch up to batch_size, then drop the padding rows
                    X_test_batch = list(X_test_batch)
                    X_test_str_batch = list(X_test_str_batch)
                    last_size = len(X_test_batch)
                    X_test_batch += [[0 for j in range(self.num_steps)] for k in range(self.batch_size - last_size)]
                    X_test_str_batch += [['x' for j in range(self.num_steps)] for k in range(self.batch_size - last_size)]
                    X_test_batch = np.array(X_test_batch)
                    X_test_str_batch = np.array(X_test_str_batch)
                    results = self.predictBatch(sess, X_test_batch, X_test_str_batch, id2label)
                    results = results[:last_size]
                else:
                    X_test_batch = np.array(X_test_batch)
                    results = self.predictBatch(sess, X_test_batch, X_test_str_batch, id2label)
                for idx in range(len(results)):
                    doc = ''.join(X_test_str_batch[idx])
                    outfile.write(doc + "<@>" + " ".join(results[idx]).encode("utf-8") + "\n")
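    # viterbi() backtracks through the argmax tables produced by forward():
    # start from the best state at each sequence's true end position, then
    # follow max_scores_pre backwards and reverse the collected path.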
    def viterbi(self, max_scores, max_scores_pre, length, predict_size=128):
        best_paths = []
        for m in range(predict_size):
            path = []
            last_max_node = np.argmax(max_scores[m][length[m]])
            # last_max_node = 0
            for t in range(1, length[m] + 1)[::-1]:
                last_max_node = max_scores_pre[m][t][last_max_node]
                path.append(last_max_node)
            path = path[::-1]
            best_paths.append(path)
        return best_paths
    def predictBatch(self, sess, X, X_str, id2label):
        results = []
        length, max_scores, max_scores_pre = sess.run([self.length, self.max_scores, self.max_scores_pre], feed_dict={self.inputs: X})
        predicts = self.viterbi(max_scores, max_scores_pre, length, self.batch_size)
        for i in range(len(predicts)):
            x = ''.join(X_str[i]).decode("utf-8")
            # filter out the extra padding class (id 5 == num_classes) and the padding label (id 0)
            y_pred = ''.join([id2label[val] for val in predicts[i] if val != 5 and val != 0])
            entitys = helper.extractEntity(x, y_pred)
            results.append(entitys)
        return results
    def evaluate(self, X, y_true, y_pred, id2char, id2label):
        precision = -1.0
        recall = -1.0
        f1 = -1.0
        hit_num = 0
        pred_num = 0
        true_num = 0
        for i in range(len(y_true)):
            x = ''.join([str(id2char[val].encode("utf-8")) for val in X[i]])
            # the original listing breaks off above; the remainder is a reconstruction
            # based on the counters above and on predictBatch()'s use of helper.extractEntity
            y = ''.join([str(id2label[val].encode("utf-8")) for val in y_true[i]])
            y_hat = ''.join([id2label[val] for val in y_pred[i] if val != 5 and val != 0])
            true_labels = helper.extractEntity(x, y)
            pred_labels = helper.extractEntity(x, y_hat)
            hit_num += len(set(true_labels) & set(pred_labels))
            pred_num += len(pred_labels)
            true_num += len(true_labels)
        if pred_num != 0:
            precision = 1.0 * hit_num / pred_num
        if true_num != 0:
            recall = 1.0 * hit_num / true_num
        if precision > 0 and recall > 0:
            f1 = 2.0 * precision * recall / (precision + recall)
        return precision, recall, f1
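For orientation, a hypothetical driver script. Only BILSTM_CRF, helper.loadMap and model.train come from the code above; the file name, the helper.getTrain loader and the hyperparameter values are my assumptions:

# train.py -- hypothetical usage sketch
import tensorflow as tf
import helper
from BILSTM_CRF import BILSTM_CRF

char2id, id2char = helper.loadMap("char2id")
label2id, id2label = helper.loadMap("label2id")
# assumed loader: X_* are [num_samples, num_steps] char-id matrices (0 = padding),
# y_* the matching label-id matrices
X_train, y_train, X_val, y_val = helper.getTrain("train.in", "validation.in")

with tf.Session() as sess:
    model = BILSTM_CRF(num_chars=len(char2id), num_classes=5, num_steps=200, num_epochs=30)
    sess.run(tf.initialize_all_variables())  # TF 0.8-era initializer
    model.train(sess, "model/best_model", X_train, y_train, X_val, y_val)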