Fine-Tuning BERT with run_classifier
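First, download the Chinese model files from this link: https://pan.baidu.com/s/1-c068UOgfhrMyIIhR5fHXg (extraction code: 2z2r). After unzipping you get five files: the vocabulary file vocab.txt, three BERT TensorFlow checkpoint files (I won't list them individually), and the parameter file bert_config.json. Then clone the BERT code (the google-research/bert repository, which contains run_classifier.py) from GitHub, and you're ready to go!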
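Before going further, it can help to confirm that the unzipped folder really contains all five files. The sketch below is a minimal check that assumes the three checkpoint files use the standard Google BERT names (bert_model.ckpt.index, bert_model.ckpt.meta and bert_model.ckpt.data-00000-of-00001) and that the folder is called BERT, as in the flag settings later in this post. missing_bert_files is a hypothetical helper, not part of the BERT code.

```python
import os

# Assumption: standard Google BERT checkpoint file names; adjust if your download differs.
EXPECTED_FILES = [
    "vocab.txt",
    "bert_config.json",
    "bert_model.ckpt.index",
    "bert_model.ckpt.meta",
    "bert_model.ckpt.data-00000-of-00001",
]


def missing_bert_files(model_dir):
    """Return the expected file names that are not present in model_dir."""
    return [name for name in EXPECTED_FILES
            if not os.path.isfile(os.path.join(model_dir, name))]


if __name__ == "__main__":
    # "BERT" is the folder name used by the flag settings in this post (an assumption here).
    missing = missing_bert_files("BERT")
    if missing:
        print("Missing files:", ", ".join(missing))
    else:
        print("All five model files found.")
```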
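Let's get started!
First, add your own task to the processors dictionary in main(). I named mine my_bert; after the name comes the processor class for the task, which we define next.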
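```python
def main(_):
    tf.logging.set_verbosity(tf.logging.INFO)  # set the logging level

    # Register your own task here: key = task name, value = processor class
    processors = {
        "cola": ColaProcessor,
        "mnli": MnliProcessor,
        "mrpc": MrpcProcessor,
        "xnli": XnliProcessor,
        "my_bert": my_bertProcessor,
    }
    # ... the rest of main() stays unchanged
```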
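For reference, main() in run_classifier.py lower-cases the --task_name value, looks it up in this dictionary, raises ValueError("Task not found: ...") if it is missing, and otherwise instantiates the class. The stdlib-only sketch below mimics that lookup so it runs without TensorFlow; DummyProcessor and get_processor are hypothetical stand-ins. The practical takeaway: keep the dictionary key lower-case.

```python
class DummyProcessor:
    """Hypothetical stand-in for my_bertProcessor, used only in this sketch."""

    def get_labels(self):
        return [str(i) for i in range(1, 36)]


processors = {
    "my_bert": DummyProcessor,
}


def get_processor(task_name):
    # run_classifier.py lower-cases --task_name before the lookup,
    # so the key in the processors dict must be lower-case as well.
    key = task_name.lower()
    if key not in processors:
        raise ValueError("Task not found: %s" % key)
    return processors[key]()  # the class is instantiated here


processor = get_processor("My_Bert")
print(len(processor.get_labels()))  # 35
```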
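Next, model your processor on one of the existing ones in processors, for example MrpcProcessor: copy the whole class MrpcProcessor(DataProcessor), paste it below as a new class, and then make a few small changes so it handles your classification task. The changes are described below.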
```python
class MrpcProcessor(DataProcessor):
    """Processor for the MRPC data set (GLUE version)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return ["0", "1"]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            if i == 0:
                continue
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[3])
            text_b = tokenization.convert_to_unicode(line[4])
            if set_type == "test":
                label = "0"
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=text_b, label=label))
        return examples
```
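To see where line[0], line[3] and line[4] come from: MRPC's .tsv files start with a header row (Quality, #1 ID, #2 ID, #1 String, #2 String), which is why the loop skips i == 0. The sketch below reads a made-up MRPC-style sample the same way DataProcessor._read_tsv does (csv.reader with a tab delimiter and quotechar=None) and pulls out the same columns. read_tsv and mrpc_pairs are illustrative names only.

```python
import csv
import io

# Made-up sample in the MRPC column layout (header + one data row).
SAMPLE_TSV = (
    "Quality\t#1 ID\t#2 ID\t#1 String\t#2 String\n"
    "1\t101\t102\tThe cat sat on the mat.\tA cat was sitting on the mat.\n"
)


def read_tsv(text):
    """Mimics DataProcessor._read_tsv: tab-delimited, no quote handling."""
    return list(csv.reader(io.StringIO(text), delimiter="\t", quotechar=None))


def mrpc_pairs(lines):
    """Return (label, text_a, text_b) tuples the way MrpcProcessor indexes them."""
    pairs = []
    for i, line in enumerate(lines):
        if i == 0:  # skip the header row
            continue
        pairs.append((line[0], line[3], line[4]))
    return pairs


for label, text_a, text_b in mrpc_pairs(read_tsv(SAMPLE_TSV)):
    print(label, "|", text_a, "|", text_b)
```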
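Rename the class to the name you registered in processors, change the categories returned by get_labels(), and make a few small adjustments to _create_examples(). Here is my version; compare it with the original and it should be clear why each change is needed.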
```python
class my_bertProcessor(DataProcessor):
    """Processor for a single-sentence, multi-class data set (label<TAB>text)."""

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        # 35 classes, labelled "1" to "35"
        return [str(i) for i in range(1, 36)]

    def _create_examples(self, lines, set_type):
        """Creates examples for the training, dev and test sets."""
        examples = []
        for (i, line) in enumerate(lines):
            # No header row to skip: every line is [label, text]
            guid = "%s-%s" % (set_type, i)
            text_a = tokenization.convert_to_unicode(line[1])
            if set_type == "test":
                # Placeholder label: ignored at prediction time, but it must be one of
                # get_labels(), otherwise convert_single_example raises KeyError ("0" is not)
                label = self.get_labels()[0]
            else:
                label = tokenization.convert_to_unicode(line[0])
            examples.append(
                InputExample(guid=guid, text_a=text_a, text_b=None, label=label))
        return examples
```
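If you want to check the new _create_examples logic without loading TensorFlow, here is a stdlib-only sketch of the same loop. It assumes each line is [label, text] with no header row; InputExample here is a namedtuple stand-in with the same field names as the class in run_classifier.py, and the sample rows are made up. For the test split it uses the first entry of the label list as a placeholder, because convert_single_example looks every example's label up in the label map, so a placeholder that is not in get_labels() raises KeyError.

```python
import collections

# Stand-in for run_classifier.InputExample (same field names, no TensorFlow needed).
InputExample = collections.namedtuple(
    "InputExample", ["guid", "text_a", "text_b", "label"])

LABELS = [str(i) for i in range(1, 36)]  # "1" .. "35", as returned by get_labels()


def create_examples(lines, set_type):
    """Same logic as my_bertProcessor._create_examples, minus tokenization."""
    examples = []
    for i, line in enumerate(lines):
        guid = "%s-%s" % (set_type, i)
        text_a = line[1]
        # Test labels are placeholders but must still exist in LABELS.
        label = LABELS[0] if set_type == "test" else line[0]
        examples.append(InputExample(guid, text_a, None, label))
    return examples


rows = [["3", "今天天氣很好"], ["12", "股市大漲"]]  # made-up sample rows
for ex in create_examples(rows, "train"):
    print(ex)
```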
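Here is why these changes are needed. First, we are doing multi-class classification, so get_labels() must return our own labels. Second, how the text is extracted depends on the format of the input data: here each line looks like [label, text], so the text comes from line[1] and the label from line[0]. MRPC's if i == 0: continue check, which skips its header row, is dropped as well, since these files have no header line. Finally, text classification works on a single sentence, so text_b is not needed and we simply set it to None.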
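The reason the labels matter so much: convert_single_example in run_classifier.py turns get_labels() into a map from label string to integer id by enumerating the list, and main() passes len(label_list) as num_labels, which sets the size of the classification layer. Every label in your data therefore has to appear in get_labels(). The sketch below reproduces that mapping and adds a hypothetical unknown_labels helper for catching stray labels before training.

```python
def build_label_map(label_list):
    """Label string -> integer id, built the same way convert_single_example does."""
    return {label: i for i, label in enumerate(label_list)}


def unknown_labels(lines, label_list):
    """Hypothetical helper: labels that occur in the data but not in get_labels()."""
    known = set(label_list)
    return sorted({line[0] for line in lines if line[0] not in known})


label_list = [str(i) for i in range(1, 36)]
label_map = build_label_map(label_list)
print(label_map["1"], label_map["35"])  # 0 34
print(unknown_labels([["3", "..."], ["36", "..."]], label_list))  # ['36']
```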
Next, the parameters. Instead of passing them on the command line as the GitHub README does, I set the defaults directly in the file. Here are my settings:
```python
flags.DEFINE_string(
    "data_dir", "data",  # name of the folder containing your own classification data
    "The input data dir. Should contain the .tsv files (or other data files) "
    "for the task.")
flags.DEFINE_string(
    "bert_config_file", "BERT/bert_config.json",  # the model's parameter file
    "The config json file corresponding to the pre-trained BERT model. "
    "This specifies the model architecture.")
flags.DEFINE_string("task_name", "my_bert", "The name of the task to train.")  # your own task name, as registered in processors
flags.DEFINE_string("vocab_file", "BERT/vocab.txt",  # one of the five unzipped files
                    "The vocabulary file that the BERT model was trained on.")
flags.DEFINE_string(
    "output_dir", "model/",  # the fine-tuned checkpoints and eval/predict results are written here
    "The output directory where the model checkpoints will be written.")
## Other parameters
flags.DEFINE_string(
    "init_checkpoint", "BERT/bert_model.ckpt",  # the downloaded pre-trained checkpoint (file prefix, without the .index/.meta/.data suffix)
    "Initial checkpoint (usually from a pre-trained BERT model).")
flags.DEFINE_bool(
    "do_lower_case", True,
    "Whether to lower case the input text. Should be True for uncased "
    "models and False for cased models.")
flags.DEFINE_integer(
    "max_seq_length", 128,
    "The maximum total input sequence length after WordPiece tokenization. "
    "Sequences longer than this will be truncated, and sequences shorter "
    "than this will be padded.")
flags.DEFINE_bool("do_train", True, "Whether to run training.")  # set to True to train
flags.DEFINE_bool("do_eval", True, "Whether to run eval on the dev set.")  # set to True to evaluate on dev.tsv
flags.DEFINE_bool(
    "do_predict", True,  # set to True to run prediction on test.tsv
    "Whether to run the model in inference mode on the test set.")
flags.DEFINE_integer("train_batch_size", 32, "Total batch size for training.")
# A larger batch size can help, but how large you can go is limited by your machine's (GPU) memory
flags.DEFINE_integer("eval_batch_size", 8, "Total batch size for eval.")
flags.DEFINE_integer("predict_batch_size", 8, "Total batch size for predict.")
flags.DEFINE_float("learning_rate", 5e-5, "The initial learning rate for Adam.")
flags.DEFINE_float("num_train_epochs", 3.0,
                   "Total number of training epochs to perform.")
flags.DEFINE_float(
    "warmup_proportion", 0.1,
    "Proportion of training to perform linear learning rate warmup for. "
    "E.g., 0.1 = 10% of training.")
flags.DEFINE_integer("save_checkpoints_steps", 100,
                     "How often to save the model checkpoint.")
flags.DEFINE_integer("iterations_per_loop", 50,
                     "How many steps to make in each estimator call.")
flags.DEFINE_bool("use_tpu", False, "Whether to use TPU or GPU/CPU.")
tf.flags.DEFINE_string(
    "tpu_name", None,
    "The Cloud TPU to use for training. This should be either the name "
    "used when creating the Cloud TPU, or a grpc://ip.address.of.tpu:8470 "
    "url.")
tf.flags.DEFINE_string(
    "tpu_zone", None,
    "[Optional] GCE zone where the Cloud TPU is located in. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
tf.flags.DEFINE_string(
    "gcp_project", None,
    "[Optional] Project name for the Cloud TPU-enabled project. If not "
    "specified, we will attempt to automatically detect the GCE project from "
    "metadata.")
tf.flags.DEFINE_string("master", None, "[Optional] TensorFlow master URL.")
flags.DEFINE_integer(
    "num_tpu_cores", 8,
    "Only used if `use_tpu` is True. Total number of TPU cores to use.")
```
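With these settings, the number of optimizer steps depends on the size of your data: main() computes num_train_steps = int(len(train_examples) / train_batch_size * num_train_epochs) and num_warmup_steps = int(num_train_steps * warmup_proportion). The sketch below repeats that arithmetic for a hypothetical training set of 10,000 examples and also counts how many times the save_checkpoints_steps interval is reached. training_schedule is an illustrative name, not a function in run_classifier.py.

```python
def training_schedule(num_examples, train_batch_size=32, num_train_epochs=3.0,
                      warmup_proportion=0.1, save_checkpoints_steps=100):
    """Return (train steps, warmup steps, save intervals reached)."""
    # Same formulas as main() in run_classifier.py
    num_train_steps = int(num_examples / train_batch_size * num_train_epochs)
    num_warmup_steps = int(num_train_steps * warmup_proportion)
    # How many full save_checkpoints_steps intervals fit into training
    save_intervals = num_train_steps // save_checkpoints_steps
    return num_train_steps, num_warmup_steps, save_intervals


# Hypothetical data size of 10,000 training examples with the flag values above
print(training_schedule(10000))  # (937, 93, 9)
```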
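Next, the input data format. run_classifier.py reads train.tsv, dev.tsv and test.tsv from data_dir, so if your .txt files are already tab-separated you can simply change their extension to .tsv. Each line of a file can look like label + \t + sentence. Then just run the script; it prints its progress automatically. Format your training data file as shown in the figure below: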
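If your raw data is a single label\tsentence .txt file, a small script can validate it and write the split files for you. This is a hypothetical helper, not part of BERT: it assumes UTF-8 input, treats any non-blank line without exactly one tab as an error (a stray tab inside a sentence would otherwise be split into an extra column by _read_tsv), and writes train.tsv and dev.tsv (a 90/10 random split by default) into data_dir. test.tsv, which is needed when do_predict is True, is not produced here.

```python
import os
import random


def write_tsv_splits(txt_path, data_dir, dev_ratio=0.1, seed=42):
    """Hypothetical helper: split a 'label<TAB>sentence' .txt file into
    train.tsv and dev.tsv inside data_dir, validating every line."""
    rows = []
    with open(txt_path, encoding="utf-8") as f:
        for lineno, raw in enumerate(f, start=1):
            raw = raw.rstrip("\r\n")
            if not raw.strip():
                continue  # skip blank lines
            parts = raw.split("\t")
            if len(parts) != 2:
                raise ValueError("line %d: expected 'label\\tsentence', got %r"
                                 % (lineno, raw))
            rows.append(parts)

    random.Random(seed).shuffle(rows)  # fixed seed for a reproducible split
    n_dev = int(len(rows) * dev_ratio)
    splits = {"dev.tsv": rows[:n_dev], "train.tsv": rows[n_dev:]}

    os.makedirs(data_dir, exist_ok=True)
    for name, split in splits.items():
        with open(os.path.join(data_dir, name), "w", encoding="utf-8") as out:
            for label, sentence in split:
                out.write(label + "\t" + sentence + "\n")
    return {name: len(split) for name, split in splits.items()}
```

For example, write_tsv_splits("all.txt", "data") would fill the data folder named in the data_dir flag above, where all.txt is a placeholder name for your raw file.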
Below is a screenshot of part of my own training run:
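Since do_predict is set to True in the flags above, run_classifier.py also writes test_results.tsv to output_dir: one line per test example, with one tab-separated probability per class in the order returned by get_labels(). The minimal sketch below turns those rows back into label strings by picking the most probable class; decode_predictions is a hypothetical helper, and the model/ path in the usage comment assumes the output_dir setting above.

```python
def decode_predictions(lines, label_list):
    """Map each row of test_results.tsv to its most probable label."""
    predicted = []
    for line in lines:
        line = line.strip()
        if not line:
            continue  # ignore blank lines
        probs = [float(p) for p in line.split("\t")]
        best = max(range(len(probs)), key=probs.__getitem__)  # argmax
        predicted.append(label_list[best])
    return predicted


# Usage, assuming output_dir = "model/":
# with open("model/test_results.tsv", encoding="utf-8") as f:
#     print(decode_predictions(f, [str(i) for i in range(1, 36)]))
```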