BiLSTM+CRF for Named Entity Recognition (NER)
阿新 • Published: 2018-12-13
# Step 1: data processing (data.py)
# pickle serializes an arbitrarily complex object into a text or binary
# representation (a string, a file on disk, or any file-like object), and
# unpickling restores the original object from that serialized form.
import sys, pickle, os, random
import numpy as np

## tags, BIO tagging scheme
tag2label = {"O": 0,
             "B-PER": 1, "I-PER": 2,
             "B-LOC": 3, "I-LOC": 4,
             "B-ORG": 5, "I-ORG": 6
             }


# Read the training corpus from corpus_path and return train_data.
def read_corpus(corpus_path):
    """
    read corpus and return the list of samples
    :param corpus_path:
    :return: data
    """
    data = []
    with open(corpus_path, encoding='utf-8') as fr:
        # lines looks like ['北\tB-LOC\n', '京\tI-LOC\n', '的\tO\n', ...];
        # 2,220,537 characters in total, each with its tag
        lines = fr.readlines()
    sent_, tag_ = [], []
    for line in lines:
        if line != '\n':
            # each line holds a char and its label separated by whitespace;
            # line.strip() removes the trailing newline, .split() separates
            # the line into the char and the label
            [char, label] = line.strip().split()
            # collect the characters into sent_
            sent_.append(char)
            # collect the tag of each character into tag_
            tag_.append(label)
        else:
            data.append((sent_, tag_))
            sent_, tag_ = [], []
    # data looks like
    # [(['我', '在', '北', '京'], ['O', 'O', 'B-LOC', 'I-LOC']),                 # sentence 1
    #  (['我', '在', '天', '安', '門'], ['O', 'O', 'B-LOC', 'I-LOC', 'I-LOC']),  # sentence 2
    #  ...]
    # 50,658 sentences in total
    return data


# Build a vocabulary from train_data: first a dict over all distinct characters,
# {'first char': [id, frequency], 'second char': [id, frequency], ...},
# then drop low-frequency characters, re-assign ids, and dump the resulting
# word->id mapping to vocab_path (word2id.pkl) with pickle's dump method.
def vocab_build(vocab_path, corpus_path, min_count):
    """
    :param vocab_path:
    :param corpus_path:
    :param min_count:
    :return:
    """
    data = read_corpus(corpus_path)
    word2id = {}
    # sent_ looks like ['我', '在', '北', '京'] with tag_ ['O', 'O', 'B-LOC', 'I-LOC']
    for sent_, tag_ in data:
        for word in sent_:
            # isdigit() returns True iff the string contains only digits
            if word.isdigit():
                word = '<NUM>'
            # A-Z: \u0041-\u005a, a-z: \u0061-\u007a
            elif ('\u0041' <= word <= '\u005a') or ('\u0061' <= word <= '\u007a'):
                word = '<ENG>'
            if word not in word2id:
                # [len(word2id)+1, 1] stores [id, count]; the first occurrence counts as 1
                word2id[word] = [len(word2id)+1, 1]
            else:
                # word2id[word][1] accumulates the frequency
                word2id[word][1] += 1
    # collect the low-frequency characters
    low_freq_words = []
    for word, [word_id, word_freq] in word2id.items():
        # characters whose frequency falls below min_count
        if word_freq < min_count and word != '<NUM>' and word != '<ENG>':
            low_freq_words.append(word)
    for word in low_freq_words:
        # remove the low-frequency characters from the vocabulary
        del word2id[word]

    # after deletion, re-assign a fresh id to every character (frequencies are no longer kept)
    new_id = 1
    for word in word2id.keys():
        word2id[word] = new_id
        new_id += 1
    word2id['<UNK>'] = new_id
    word2id['<PAD>'] = 0

    print(len(word2id))
    with open(vocab_path, 'wb') as fw:
        # serialize to the file word2id.pkl
        pickle.dump(word2id, fw)


# Load word2id back with pickle's load method (deserialization).
def read_dictionary(vocab_path):
    """
    :param vocab_path:
    :return:
    """
    vocab_path = os.path.join(vocab_path)
    with open(vocab_path, 'rb') as fr:
        # deserialize and return
        word2id = pickle.load(fr)
    print('vocab_size:', len(word2id))
    return word2id
# word2id looks like {'當': 1, '希': 2, '望': 3, '工': 4, '程': 5, ..., '<UNK>': 3904, '<PAD>': 0}:
# 3903 distinct corpus characters plus <UNK> and <PAD>, 3905 entries in total


# Map a sentence to an id sequence.
# sentence_id looks like [1, 2, 3, 4, ...] for sent ['當', '希', '望', '工', '程', ...]
def sentence2id(sent, word2id):
    """
    :param sent:
    :param word2id:
    :return:
    """
    sentence_id = []
    for word in sent:
        if word.isdigit():
            word = '<NUM>'
        elif ('\u0041' <= word <= '\u005a') or ('\u0061' <= word <= '\u007a'):
            word = '<ENG>'
        # characters missing from word2id are represented by <UNK>
        if word not in word2id:
            word = '<UNK>'
        sentence_id.append(word2id[word])
    return sentence_id


# vocab is the word2id built above; embedding_dim = 300
def random_embedding(vocab, embedding_dim):
    """
    :param vocab:
    :param embedding_dim:
    :return:
    """
    # a len(vocab) x embedding_dim = 3905 x 300 matrix (every char projected
    # to 300 dims) as the initial embedding
    embedding_mat = np.random.uniform(-0.25, 0.25, (len(vocab), embedding_dim))
    embedding_mat = np.float32(embedding_mat)
    return embedding_mat


# padding: bring every sample in a batch up to the same length with pad_mark
# input : sequences, a list of id lists, e.g. [[33, 12, 17, 88, 50],          # sentence 1
#                                              [52, 19, 14, 48, 66, 31, 89]]  # sentence 2
# output: seq_list, the padded sequences; seq_len_list, the true length of each
#         sample before padding. Both are fed to feed_dict.
def pad_sequences(sequences, pad_mark=0):
    """
    :param sequences:
    :param pad_mark:
    :return:
    """
    # length of the longest sample in the batch
    max_len = max(map(lambda x: len(x), sequences))
    seq_list, seq_len_list = [], []
    for seq in sequences:
        # convert tuples () to lists []
        seq = list(seq)
        # pad short samples with pad_mark up to max_len and collect them in seq_list
        seq_ = seq[:max_len] + [pad_mark] * max(max_len - len(seq), 0)
        seq_list.append(seq_)
        # seq_len_list records the true length of each sample
        seq_len_list.append(min(len(seq), max_len))
    return seq_list, seq_len_list


# Generate batches.
# seqs   looks like [[33, 12, 17, 88, 50, ...],  # sentence 1
#                    [52, 19, 14, 48, 66, ...]]  # sentence 2
# labels looks like [[0, 0, 3, 4],               # sentence 1
#                    [0, 0, 3, 4]]               # sentence 2
def batch_yield(data, batch_size, vocab, tag2label, shuffle=False):
    """
    :param data:
    :param batch_size:
    :param vocab:
    :param tag2label:
    :param shuffle:
    :return:
    """
    if shuffle:
        random.shuffle(data)

    seqs, labels = [], []
    for (sent_, tag_) in data:
        # sent_ becomes the id sequence of the sentence, e.g. [33, 12, 17, 88, 50, ...];
        # tag_ such as ['O', 'O', 'B-LOC', 'I-LOC'] becomes label_ [0, 0, 3, 4],
        # i.e. the value of each tag in the tag2label dict
        sent_ = sentence2id(sent_, vocab)
        label_ = [tag2label[tag] for tag in tag_]
        # emit a batch once seqs has grown to batch_size
        if len(seqs) == batch_size:
            yield seqs, labels
            seqs, labels = [], []
        seqs.append(sent_)
        labels.append(label_)

    if len(seqs) != 0:
        yield seqs, labels
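As a quick sanity check on the helpers above (a minimal sketch, assuming they are saved as data.py, which matches the imports used later), padding behaves like this:

from data import pad_sequences

seqs = [[33, 12, 17, 88, 50],
        [52, 19, 14, 48, 66, 31, 89]]
seq_list, seq_len_list = pad_sequences(seqs, pad_mark=0)
print(seq_list)      # [[33, 12, 17, 88, 50, 0, 0], [52, 19, 14, 48, 66, 31, 89]]
print(seq_len_list)  # [5, 7]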
# Step 2: define the model (model.py)
import numpy as np
import os, time, sys
import tensorflow as tf
from tensorflow.contrib.rnn import LSTMCell
from tensorflow.contrib.crf import crf_log_likelihood
from tensorflow.contrib.crf import viterbi_decode
from data import pad_sequences, batch_yield
from utils import get_logger
from eval import conlleval


class BiLSTM_CRF(object):
    def __init__(self, args, embeddings, tag2label, vocab, paths, config):
        self.batch_size = args.batch_size
        self.epoch_num = args.epoch
        self.hidden_dim = args.hidden_dim
        self.embeddings = embeddings
        self.CRF = args.CRF
        self.update_embedding = args.update_embedding
        self.dropout_keep_prob = args.dropout
        self.optimizer = args.optimizer
        self.lr = args.lr
        self.clip_grad = args.clip
        self.tag2label = tag2label
        self.num_tags = len(tag2label)
        self.vocab = vocab
        self.shuffle = args.shuffle
        self.model_path = paths['model_path']
        self.summary_path = paths['summary_path']
        self.logger = get_logger(paths['log_path'])
        self.result_path = paths['result_path']
        self.config = config

    def build_graph(self):
        self.add_placeholders()
        self.lookup_layer_op()
        self.biLSTM_layer_op()
        self.softmax_pred_op()
        self.loss_op()
        self.trainstep_op()
        self.init_op()

    def add_placeholders(self):
        self.word_ids = tf.placeholder(tf.int32, shape=[None, None], name="word_ids")
        self.labels = tf.placeholder(tf.int32, shape=[None, None], name="labels")
        self.sequence_lengths = tf.placeholder(tf.int32, shape=[None], name="sequence_lengths")
        self.dropout_pl = tf.placeholder(dtype=tf.float32, shape=[], name="dropout")
        self.lr_pl = tf.placeholder(dtype=tf.float32, shape=[], name="lr")

    def lookup_layer_op(self):
        with tf.variable_scope("words"):
            _word_embeddings = tf.Variable(self.embeddings,
                                           dtype=tf.float32,
                                           trainable=self.update_embedding,
                                           name="_word_embeddings")
            word_embeddings = tf.nn.embedding_lookup(params=_word_embeddings,
                                                     ids=self.word_ids,
                                                     name="word_embeddings")
        self.word_embeddings = tf.nn.dropout(word_embeddings, self.dropout_pl)

    def biLSTM_layer_op(self):
        with tf.variable_scope("bi-lstm"):
            cell_fw = LSTMCell(self.hidden_dim)
            cell_bw = LSTMCell(self.hidden_dim)
            (output_fw_seq, output_bw_seq), _ = tf.nn.bidirectional_dynamic_rnn(
                cell_fw=cell_fw,
                cell_bw=cell_bw,
                inputs=self.word_embeddings,
                sequence_length=self.sequence_lengths,
                dtype=tf.float32)
            # concatenate the forward and backward outputs along the last dimension
            output = tf.concat([output_fw_seq, output_bw_seq], axis=-1)
            # apply dropout
            output = tf.nn.dropout(output, self.dropout_pl)

        with tf.variable_scope("proj"):
            W = tf.get_variable(name="W",
                                shape=[2 * self.hidden_dim, self.num_tags],
                                # the "Xavier" initializer keeps the gradient
                                # scale roughly the same across layers
                                initializer=tf.contrib.layers.xavier_initializer(),
                                dtype=tf.float32)
            b = tf.get_variable(name="b",
                                shape=[self.num_tags],
                                # tf.zeros_initializer() initializes the bias to zeros
                                initializer=tf.zeros_initializer(),
                                dtype=tf.float32)
            # output has shape [batch_size, steps, cell_num]
            s = tf.shape(output)
            # reshape so output can be matrix-multiplied with W
            output = tf.reshape(output, [-1, 2 * self.hidden_dim])
            pred = tf.matmul(output, W) + b
            # s[1] = steps, the max sequence length in the batch
            self.logits = tf.reshape(pred, [-1, s[1], self.num_tags])

    def loss_op(self):
        if self.CRF:
            # crf_log_likelihood as the loss function:
            #   inputs:            unary potentials, the per-tag scores from the projection layer
            #   tag_indices:       the gold tag sequence
            #   sequence_lengths:  the true length of each sample; padding aligns the
            #                      lengths, but the real lengths go into this argument
            #   transition_params: transition scores; optional, computed internally if omitted
            # outputs: log_likelihood (a scalar per sample) and transition_params
            #          (returned even when not passed in)
            log_likelihood, self.transition_params = crf_log_likelihood(inputs=self.logits,
                                                                        tag_indices=self.labels,
                                                                        sequence_lengths=self.sequence_lengths)
            self.loss = -tf.reduce_mean(log_likelihood)
        else:
            # cross-entropy as the loss function
            losses = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=self.logits,
                                                                    labels=self.labels)
            mask = tf.sequence_mask(self.sequence_lengths)
            losses = tf.boolean_mask(losses, mask)
            self.loss = tf.reduce_mean(losses)

        # add a scalar summary for the loss
        tf.summary.scalar("loss", self.loss)

    def softmax_pred_op(self):
        if not self.CRF:
            self.labels_softmax_ = tf.argmax(self.logits, axis=-1)
            self.labels_softmax_ = tf.cast(self.labels_softmax_, tf.int32)

    def trainstep_op(self):
        with tf.variable_scope("train_step"):
            self.global_step = tf.Variable(0, name="global_step", trainable=False)
            if self.optimizer == 'Adam':
                optim = tf.train.AdamOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Adadelta':
                optim = tf.train.AdadeltaOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Adagrad':
                optim = tf.train.AdagradOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'RMSProp':
                optim = tf.train.RMSPropOptimizer(learning_rate=self.lr_pl)
            elif self.optimizer == 'Momentum':
                optim = tf.train.MomentumOptimizer(learning_rate=self.lr_pl, momentum=0.9)
            elif self.optimizer == 'SGD':
                optim = tf.train.GradientDescentOptimizer(learning_rate=self.lr_pl)
            else:
                optim = tf.train.GradientDescentOptimizer(learning_rate=self.lr_pl)

            grads_and_vars = optim.compute_gradients(self.loss)
            # clip every gradient to [-clip_grad, clip_grad]
            grads_and_vars_clip = [[tf.clip_by_value(g, -self.clip_grad, self.clip_grad), v]
                                   for g, v in grads_and_vars]
            self.train_op = optim.apply_gradients(grads_and_vars_clip, global_step=self.global_step)

    def init_op(self):
        self.init_op = tf.global_variables_initializer()

    def add_summary(self, sess):
        """
        :param sess:
        :return:
        """
        self.merged = tf.summary.merge_all()
        self.file_writer = tf.summary.FileWriter(self.summary_path, sess.graph)

    def train(self, train, dev):
        """
        :param train:
        :param dev:
        :return:
        """
        saver = tf.train.Saver(tf.global_variables())
        with tf.Session(config=self.config) as sess:
            sess.run(self.init_op)
            self.add_summary(sess)
            # epoch_num = 40
            for epoch in range(self.epoch_num):
                self.run_one_epoch(sess, train, dev, self.tag2label, epoch, saver)

    def test(self, test):
        saver = tf.train.Saver()
        with tf.Session(config=self.config) as sess:
            self.logger.info('=========== testing ===========')
            saver.restore(sess, self.model_path)
            label_list, seq_len_list = self.dev_one_epoch(sess, test)
            self.evaluate(label_list, seq_len_list, test)

    def demo_one(self, sess, sent):
        """
        :param sess:
        :param sent:
        :return:
        """
        label_list = []
        for seqs, labels in batch_yield(sent, self.batch_size, self.vocab, self.tag2label, shuffle=False):
            label_list_, _ = self.predict_one_batch(sess, seqs)
            label_list.extend(label_list_)
        # note: label 0 is kept as 0 here rather than mapped back to 'O'
        label2tag = {}
        for tag, label in self.tag2label.items():
            label2tag[label] = tag if label != 0 else label
        tag = [label2tag[label] for label in label_list[0]]
        return tag

    def run_one_epoch(self, sess, train, dev, tag2label, epoch, saver):
        """
        :param sess:
        :param train:
        :param dev:
        :param tag2label:
        :param epoch:
        :param saver:
        :return:
        """
        # number of batches: (50658 + 64 - 1) // 64 = 792
        num_batches = (len(train) + self.batch_size - 1) // self.batch_size
        # record the training start time
        start_time = time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())
        # yield the batches one by one
        batches = batch_yield(train, self.batch_size, self.vocab, self.tag2label, shuffle=self.shuffle)
        for step, (seqs, labels) in enumerate(batches):
            # sys.stdout is standard output; write() prints the progress line
            sys.stdout.write(' processing: {} batch / {} batches.'.format(step + 1, num_batches) + '\r')
            # step_num = epoch * 792 + step + 1
            step_num = epoch * num_batches + step + 1
            feed_dict, _ = self.get_feed_dict(seqs, labels, self.lr, self.dropout_keep_prob)
            _, loss_train, summary, step_num_ = sess.run([self.train_op, self.loss, self.merged, self.global_step],
                                                         feed_dict=feed_dict)
            if step + 1 == 1 or (step + 1) % 300 == 0 or step + 1 == num_batches:
                self.logger.info('{} epoch {}, step {}, loss: {:.4}, global_step: {}'.format(
                    start_time, epoch + 1, step + 1, loss_train, step_num))

            self.file_writer.add_summary(summary, step_num)

            if step + 1 == num_batches:
                # save the model after the last batch of the epoch
                saver.save(sess, self.model_path, global_step=step_num)

        self.logger.info('===========validation / test===========')
        label_list_dev, seq_len_list_dev = self.dev_one_epoch(sess, dev)
        self.evaluate(label_list_dev, seq_len_list_dev, dev, epoch)

    def get_feed_dict(self, seqs, labels=None, lr=None, dropout=None):
        """
        :param seqs:
        :param labels:
        :param lr:
        :param dropout:
        :return: feed_dict
        """
        # word_ids is seq_list, the padded id sequences;
        # seq_len_list records the true length of each sample
        word_ids, seq_len_list = pad_sequences(seqs, pad_mark=0)
        feed_dict = {self.word_ids: word_ids,
                     self.sequence_lengths: seq_len_list}
        if labels is not None:
            # labels are padded as well before being fed
            labels_, _ = pad_sequences(labels, pad_mark=0)
            feed_dict[self.labels] = labels_
        if lr is not None:
            feed_dict[self.lr_pl] = lr
        if dropout is not None:
            feed_dict[self.dropout_pl] = dropout
        return feed_dict, seq_len_list

    def dev_one_epoch(self, sess, dev):
        """
        :param sess:
        :param dev:
        :return:
        """
        label_list, seq_len_list = [], []
        for seqs, labels in batch_yield(dev, self.batch_size, self.vocab, self.tag2label, shuffle=False):
            label_list_, seq_len_list_ = self.predict_one_batch(sess, seqs)
            label_list.extend(label_list_)
            seq_len_list.extend(seq_len_list_)
        return label_list, seq_len_list

    def predict_one_batch(self, sess, seqs):
        """
        :param sess:
        :param seqs:
        :return: label_list
                 seq_len_list
        """
        # seq_len_list records the true length of each sample
        feed_dict, seq_len_list = self.get_feed_dict(seqs, dropout=1.0)
        if self.CRF:
            # transition_params holds the transition scores learned via crf_log_likelihood
            logits, transition_params = sess.run([self.logits, self.transition_params],
                                                 feed_dict=feed_dict)
            label_list = []
            # zip into a list of (logit, seq_len) tuples
            for logit, seq_len in zip(logits, seq_len_list):
                viterbi_seq, _ = viterbi_decode(logit[:seq_len], transition_params)
                label_list.append(viterbi_seq)
            return label_list, seq_len_list
        else:
            label_list = sess.run(self.labels_softmax_, feed_dict=feed_dict)
            return label_list, seq_len_list

    def evaluate(self, label_list, seq_len_list, data, epoch=None):
        """
        :param label_list:
        :param seq_len_list:
        :param data:
        :param epoch:
        :return:
        """
        label2tag = {}
        for tag, label in self.tag2label.items():
            label2tag[label] = tag if label != 0 else label

        model_predict = []
        for label_, (sent, tag) in zip(label_list, data):
            tag_ = [label2tag[label__] for label__ in label_]
            sent_res = []
            if len(label_) != len(sent):
                print(sent)
                print(len(label_))
                print(tag)
            for i in range(len(sent)):
                sent_res.append([sent[i], tag[i], tag_[i]])
            model_predict.append(sent_res)
        epoch_num = str(epoch + 1) if epoch is not None else 'test'
        label_path = os.path.join(self.result_path, 'label_' + epoch_num)
        metric_path = os.path.join(self.result_path, 'result_metric_' + epoch_num)
        for _ in conlleval(model_predict, label_path, metric_path):
            self.logger.info(_)
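To make the CRF decoding step concrete, viterbi_decode can be exercised on its own outside the graph; this is a minimal sketch with made-up unary and transition scores (the numbers are illustrative, not from a trained model):

import numpy as np
from tensorflow.contrib.crf import viterbi_decode

# made-up unary scores for a 4-step sequence over 3 tags (say O, B-LOC, I-LOC)
score = np.array([[3.0, 0.1, 0.1],
                  [0.2, 2.5, 0.3],
                  [0.1, 0.4, 2.8],
                  [2.0, 0.3, 0.2]], dtype=np.float32)
# made-up transition scores, e.g. B-LOC -> I-LOC is favoured
trans = np.array([[ 1.0,  0.5, -1.0],
                  [-0.5,  0.0,  1.5],
                  [ 0.3, -0.2,  0.8]], dtype=np.float32)
viterbi_seq, viterbi_score = viterbi_decode(score, trans)
print(viterbi_seq)  # -> [0, 1, 2, 0]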
# Step 3: utilities (utils.py)
import logging, sys, argparse


def str2bool(v):
    # copy from StackOverflow
    if v.lower() in ('yes', 'true', 't', 'y', '1'):
        return True
    elif v.lower() in ('no', 'false', 'f', 'n', '0'):
        return False
    else:
        # any other value makes argparse report a clean error
        raise argparse.ArgumentTypeError('Boolean value expected.')


# Given the predicted tag sequence, return the characters of each entity type.
def get_entity(tag_seq, char_seq):
    PER = get_PER_entity(tag_seq, char_seq)
    LOC = get_LOC_entity(tag_seq, char_seq)
    ORG = get_ORG_entity(tag_seq, char_seq)
    return PER, LOC, ORG


# Collect the characters of PER entities.
def get_PER_entity(tag_seq, char_seq):
    length = len(char_seq)
    PER = []
    # zip pairs the characters with their tags, e.g. [('我', 'O'), ('北', 'B-LOC'), ...];
    # zip packs corresponding elements of its iterables into tuples, and with a
    # leading * it performs the inverse operation
    for i, (char, tag) in enumerate(zip(char_seq, tag_seq)):
        # the tags are O, B-PER, I-PER, B-LOC, I-LOC, B-ORG, I-ORG
        if tag == 'B-PER':
            if 'per' in locals().keys():
                PER.append(per)
                del per
            per = char
            if i+1 == length:
                PER.append(per)
        if tag == 'I-PER':
            per += char
            if i+1 == length:
                PER.append(per)
        if tag not in ['I-PER', 'B-PER']:
            if 'per' in locals().keys():
                PER.append(per)
                del per
            continue
    return PER


# Collect the characters of LOC entities.
def get_LOC_entity(tag_seq, char_seq):
    length = len(char_seq)
    LOC = []
    for i, (char, tag) in enumerate(zip(char_seq, tag_seq)):
        if tag == 'B-LOC':
            if 'loc' in locals().keys():
                LOC.append(loc)
                del loc
            loc = char
            if i+1 == length:
                LOC.append(loc)
        if tag == 'I-LOC':
            loc += char
            if i+1 == length:
                LOC.append(loc)
        if tag not in ['I-LOC', 'B-LOC']:
            if 'loc' in locals().keys():
                LOC.append(loc)
                del loc
            continue
    return LOC


# Collect the characters of ORG entities.
def get_ORG_entity(tag_seq, char_seq):
    length = len(char_seq)
    ORG = []
    for i, (char, tag) in enumerate(zip(char_seq, tag_seq)):
        if tag == 'B-ORG':
            if 'org' in locals().keys():
                ORG.append(org)
                del org
            org = char
            if i+1 == length:
                ORG.append(org)
        if tag == 'I-ORG':
            org += char
            if i+1 == length:
                ORG.append(org)
        if tag not in ['I-ORG', 'B-ORG']:
            if 'org' in locals().keys():
                ORG.append(org)
                del org
            continue
    return ORG


# Logging helper.
def get_logger(filename):
    logger = logging.getLogger('logger')
    logger.setLevel(logging.DEBUG)
    logging.basicConfig(format='%(message)s', level=logging.DEBUG)
    handler = logging.FileHandler(filename)
    handler.setLevel(logging.DEBUG)
    handler.setFormatter(logging.Formatter('%(asctime)s:%(levelname)s: %(message)s'))
    logging.getLogger().addHandler(handler)
    return logger
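A quick sketch of how get_entity behaves on a decoded tag sequence (the sentence here is made up for illustration):

from utils import get_entity

char_seq = ['周', '杰', '倫', '在', '臺', '北']
tag_seq  = ['B-PER', 'I-PER', 'I-PER', 'O', 'B-LOC', 'I-LOC']
PER, LOC, ORG = get_entity(tag_seq, char_seq)
print(PER, LOC, ORG)  # ['周杰倫'] ['臺北'] []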
# Step 4: evaluation (eval.py)
import os
# Evaluate the CRF test results with the conlleval.pl script
def conlleval(label_predict, label_path, metric_path):
    """
    :param label_predict:
    :param label_path:
    :param metric_path:
    :return:
    """
    eval_perl = "./conlleval_rev.pl"
    with open(label_path, "w") as fw:
        line = []
        for sent_result in label_predict:
            for char, tag, tag_ in sent_result:
                # the script expects '0' in place of the gold O tag
                tag = '0' if tag == 'O' else tag
                # write the char as text directly; the original char.encode("utf-8")
                # would write byte reprs like b'...' under Python 3
                line.append("{} {} {}\n".format(char, tag, tag_))
            line.append("\n")
        fw.writelines(line)
    os.system("perl {} < {} > {}".format(eval_perl, label_path, metric_path))
    with open(metric_path) as fr:
        metrics = [line.strip() for line in fr]
    return metrics
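Each token therefore occupies one line of the label file in the form "<char> <gold-tag> <predicted-tag>", with a blank line between sentences; note that a gold 'O' is written as '0'. For example:

北 B-LOC B-LOC
京 I-LOC I-LOC
的 0 O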
# Step 5: run the pipeline (main.py)
import tensorflow as tf
import numpy as np
## the os module wraps operating-system operations
import os, argparse, time, random
from model import BiLSTM_CRF
from utils import str2bool, get_logger, get_entity
from data import read_corpus, read_dictionary, tag2label, random_embedding
## Session configuration
# select which GPU to use from the Python code
os.environ['CUDA_VISIBLE_DEVICES'] = '0'
# TF_CPP_MIN_LOG_LEVEL controls TensorFlow's C++ logging: '0' (the default) shows
# everything, while '2' suppresses INFO and WARNING messages so only errors are printed
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'  # default: 0
# setting log_device_placement=True in tf.ConfigProto() logs which device
# (which CPU or GPU) each operation and Tensor is assigned to, printed in the terminal
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
config.gpu_options.per_process_gpu_memory_fraction = 0.2 # need ~700MB GPU memory
## hyperparameters
# the first step with argparse is to create a parser object and declare the
# expected arguments; at run time the parser then handles the command line
parser = argparse.ArgumentParser(description='BiLSTM-CRF for Chinese NER task')
parser.add_argument('--train_data', type=str, default='data_path', help='train data source')
parser.add_argument('--test_data', type=str, default='data_path', help='test data source')
parser.add_argument('--batch_size', type=int, default=64, help='#sample of each minibatch')
parser.add_argument('--epoch', type=int, default=40, help='#epoch of training')
parser.add_argument('--hidden_dim', type=int, default=300, help='#dim of hidden state')
parser.add_argument('--optimizer', type=str, default='Adam', help='Adam/Adadelta/Adagrad/RMSProp/Momentum/SGD')
parser.add_argument('--CRF', type=str2bool, default=True, help='use CRF at the top layer. if False, use Softmax')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate')
parser.add_argument('--clip', type=float, default=5.0, help='gradient clipping')
parser.add_argument('--dropout', type=float, default=0.5, help='dropout keep_prob')
parser.add_argument('--update_embedding', type=str2bool, default=True, help='update embedding during training')
parser.add_argument('--pretrain_embedding', type=str, default='random', help='use pretrained char embedding or init it randomly')
parser.add_argument('--embedding_dim', type=int, default=300, help='random init char embedding_dim')
parser.add_argument('--shuffle', type=str2bool, default=True, help='shuffle training data before each epoch')
parser.add_argument('--mode', type=str, default='demo', help='train/test/demo')
parser.add_argument('--demo_model', type=str, default='1521112368', help='model for test and demo')
# parse the command-line arguments to pass into the model
args = parser.parse_args()
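For reference, the script can then be launched from the command line; a minimal sketch, assuming the file is saved as main.py (the post does not name it):

python main.py --mode train --train_data data_path --test_data data_path
python main.py --mode test --demo_model 1521112368
python main.py --mode demo --demo_model 1521112368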
#get char embeddings
# word2id looks like {'當': 1, '希': 2, '望': 3, '工': 4, '程': 5, ..., '<UNK>': 3904, '<PAD>': 0};
# train_data has 3903 distinct characters, 3905 vocabulary entries with <UNK> and <PAD>
word2id = read_dictionary(os.path.join('.', args.train_data, 'word2id.pkl'))
# random_embedding returns a len(vocab) x embedding_dim = 3905 x 300 matrix
# (entries uniform in [-0.25, 0.25]) as the initial embedding
if args.pretrain_embedding == 'random':
embeddings = random_embedding(word2id, args.embedding_dim)
else:
embedding_path = 'pretrain_embedding.npy'
embeddings = np.array(np.load(embedding_path), dtype='float32')
# read corpus and get training data
if args.mode != 'demo':
    # train_path points to the train_data file under data_path
    train_path = os.path.join('.', args.train_data, 'train_data')
    # test_path points to the test_data file under data_path
    test_path = os.path.join('.', args.test_data, 'test_data')
    # read train_data with read_corpus; train_data looks like
    # [(['我', '在', '北', '京'], ['O', 'O', 'B-LOC', 'I-LOC']),                 # sentence 1
    #  (['我', '在', '天', '安', '門'], ['O', 'O', 'B-LOC', 'I-LOC', 'I-LOC']),  # sentence 2
    #  ...] with 50,658 sentences in total
    train_data = read_corpus(train_path)
    test_data = read_corpus(test_path); test_size = len(test_data)
## paths setting
paths = {}
# the timestamp names this run's output directory: a fresh timestamp is
# generated when training, otherwise the saved demo_model directory is reused
timestamp = str(int(time.time())) if args.mode == 'train' else args.demo_model
# output_path is a directory named after the timestamp under data_path_save
output_path = os.path.join('.', args.train_data+"_save", timestamp)
if not os.path.exists(output_path): os.makedirs(output_path)
# summary_path is the summaries directory under output_path
summary_path = os.path.join(output_path, "summaries")
paths['summary_path'] = summary_path
if not os.path.exists(summary_path): os.makedirs(summary_path)
# model_path is the checkpoints directory under output_path
model_path = os.path.join(output_path, "checkpoints/")
if not os.path.exists(model_path): os.makedirs(model_path)
# ckpt_prefix names the checkpoint files "model" under checkpoints/
ckpt_prefix = os.path.join(model_path, "model")
paths['model_path'] = ckpt_prefix
# result_path is the results directory under the timestamp directory
result_path = os.path.join(output_path, "results")
paths['result_path'] = result_path
if not os.path.exists(result_path): os.makedirs(result_path)
# log_path = output_path/results/log.txt
log_path = os.path.join(result_path, "log.txt")
paths['log_path'] = log_path
get_logger(log_path).info(str(args))
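Put together, one run writes everything under the timestamped directory; the layout looks roughly like this (timestamp and step numbers are illustrative):

data_path_save/
  1521112368/              # timestamp (or args.demo_model when not training)
    summaries/             # TensorBoard event files
    checkpoints/           # checkpoint files saved as model-<global_step>.*
    results/
      label_1              # per-epoch label file fed to conlleval.pl
      result_metric_1      # per-epoch conlleval metrics
      log.txt              # training log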
## training model
if args.mode == 'train':
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    # build the graph nodes; no return value
    model.build_graph()

    ## hyperparameters-tuning, split train/dev
    # dev_data = train_data[:5000]; dev_size = len(dev_data)
    # train_data = train_data[5000:]; train_size = len(train_data)
    # print("train data: {0}\ndev data: {1}".format(train_size, dev_size))
    # model.train(train=train_data, dev=dev_data)

    ## train model on the whole training data
    print("train data: {}".format(len(train_data)))
    # use test_data as the dev_data to observe overfitting
    model.train(train=train_data, dev=test_data)
## testing model
elif args.mode == 'test':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    print("test data: {}".format(test_size))
    model.test(test_data)
## demo
elif args.mode == 'demo':
    ckpt_file = tf.train.latest_checkpoint(model_path)
    print(ckpt_file)
    paths['model_path'] = ckpt_file
    model = BiLSTM_CRF(args, embeddings, tag2label, word2id, paths, config=config)
    model.build_graph()
    saver = tf.train.Saver()
    with tf.Session(config=config) as sess:
        print('============= demo =============')
        saver.restore(sess, ckpt_file)
        # equivalent to while True
        while(1):
            print('Please input your sentence:')
            # input() reads a line from standard input and returns a string, e.g. '我是中國人'
            demo_sent = input()
            # stop when the input is empty or whitespace only
            if demo_sent == '' or demo_sent.isspace():
                print('See you next time!')
                break
            else:
                # strip leading/trailing whitespace and split into characters
                demo_sent = list(demo_sent.strip())
                # [(['我', '是', '中', '國', '人'], ['O', 'O', 'O', 'O', 'O'])]
                demo_data = [(demo_sent, ['O'] * len(demo_sent))]
                # run the model to predict the tag of each character,
                # e.g. ['O', 'O', 'B-LOC', 'I-LOC', 'O']
                tag = model.demo_one(sess, demo_data)
                # extract the characters behind the predicted tags, e.g. LOC: 中國
                PER, LOC, ORG = get_entity(tag, demo_sent)
                print('PER: {}\nLOC: {}\nORG: {}'.format(PER, LOC, ORG))
Sample output:
Please input your sentence:
崔永元早些年向紅十字會捐過錢
PER: ['崔永元']
LOC: []
ORG: ['紅十字會']
Please input your sentence:
蔡依林在臺北的時候追求過周杰倫
PER: ['蔡依林', '周杰倫']
LOC: ['臺北']
ORG: []
As for the dataset: leave a comment and I'll share it with you.