BERT模型從訓練到部署
BERT模型從訓練到部署全流程
標籤: BERT 訓練 部署
緣起
在群裡看到許多朋友在使用BERT模型,網上多數文章只提到了模型的訓練方法,後面的生產部署及呼叫並沒有說明。 這段時間使用BERT模型完成了從資料準備到生產部署的全流程,在這裡整理出來,方便大家參考。
在下面我將以一個“手機評論的情感分類”為例子,簡要說明從訓練到部署的全部流程。最終完成後可以使用一個網頁進行互動,實時地對輸入的評論語句進行分類判斷。
基本架構
基本架構為:
graph LR
A(BERT模型服務端) --> B(API服務端)
B-->A
B --> C(應用端)
C-->B
+-------------------+ | 應用端(HTML) | +-------------------+ ^^ || VV +-------------------+ | API服務端 | +-------------------+ ^^ || VV +-------------------+ | BERT模型服務端 | +-------------------+
架構說明: BERT模型服務端 載入模型,進行實時預測的服務; 使用的是 BERT-BiLSTM-CRF-NER
API服務端 呼叫實時預測服務,為應用提供API介面的服務,用flask編寫;
應用端 最終的應用端; 我這裡使用一個HTML網頁來實現;
本專案完整原始碼地址:BERT從訓練到部署git原始碼 專案部落格地址: BERT從訓練到部署
附件: 本例中訓練完成的模型檔案.ckpt格式及.pb格式檔案,由於比較大,已放到網盤提供下載:
連結:https://pan.baidu.com/s/1DgVjRK7zicbTlAAkFp7nWw 提取碼:8iaw
如果你想跳過前面模型的訓練過程,可以直接使用訓練好的模型,來完成後面的部署。
關鍵節點
主要包括以下關鍵節點:
- 資料準備
- 模型訓練
- 模型格式轉化
- 服務端部署與啟動
- API服務編寫與部署
- 客戶端(網頁端的編寫與部署)
資料準備
這裡用的資料是手機的評論,資料比較簡單,三個分類: -1,0,1 表示負面,中性與正面情感 資料格式如下:
1 手機很好,漂亮時尚,贈品一般 1 手機很好。包裝也很完美,贈品也是收到貨後馬上就發貨了 1 第一次在第三方買的手機 開始很擔心 不過查一下是正品 很滿意 1 很不錯 續航好 系統流暢 1 不知道真假,相信店家吧 1 快遞挺快的,榮耀10手感還是不錯的,玩了會王者還不錯,就是前後玻璃, 1 流很快,手機到手感覺很酷,白色適合女士,很驚豔!常好,執行速度快,流暢! 1 用了一天才來評價,都還可以,很滿意 1 幻影藍很好看啊,炫彩系列時尚時尚最時尚,速度快,配送執行?做活動優惠買的,開心? 1 快遞速度快,很贊!軟體更新到最新版。安裝上軟膠保護套拿手上不容易滑落。 0 手機出廠貼膜好薄啊,感覺像塑料膜。其他不能發表 0 用了一段時間,除了手機續航其它還不錯。 0 做工一般 1 挺好的,贊一個,手機很好,很喜歡 0 手機還行,但是手機剛開箱時螢幕和背面有很多指紋痕跡,手機殼跟**在地上磨過似的,好幾條印子。要不是看在能把這些痕跡擦掉,和閒退貨麻煩,就給退了。就不能規規矩矩做生意麼。還有送的都是什麼吊東西,運動手環垃圾一比,貼在手機後面的固定手環還**是塑料的渡了一層銀色,耳機也和圖片描述不符,碎屏險已經註冊,不知道怎麼樣。講真的,要不就別送或者少送,要不,就規規矩矩的,不然到最後還讓人覺得不舒服。其他沒什麼。 -1 手機整體還可以,拍照也很清楚,也很流暢支援華為。給一星是因為有缺陷,送的耳機是壞的!評論區好評太多,需要一些差評來提醒下,以後更加註意細節,提升質量。 0 前天剛買的, 看著還行, 指紋解鎖反應不錯。 1 高階大氣上檔次。 -1 各位小主,注意啦,耳機是沒有的,需要單獨買 0 外觀不錯,感覺很耗電啊,在使用段時間評價 1 手機非常好,很好用 -1 沒有發票,圖片與實物不一致 1 習慣在京東採購物品,方便快捷,及時開發票進行報銷,配送員服務也很周到!就是手機收到時沒有電,感覺不大正常 1 高階大氣上檔次啊!看電影玩遊戲估計很爽!螢幕夠大!
資料總共8097條,按6:2:2的比例拆分成train.tsv,test.tsv ,dev.tsv三個資料檔案
模型訓練
訓練模型就直接使用BERT的分類方法,把原來的run_classifier.py
複製出來並修改為 run_mobile.py
。關於訓練的程式碼網上很多,就不展開說明了,主要有以下方法:
#-----------------------------------------
#手機評論情感分類資料處理 2019/3/12
#labels: -1負面 0中性 1正面
class SetimentProcessor(DataProcessor):
def get_train_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")
def get_dev_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")
def get_test_examples(self, data_dir):
"""See base class."""
return self._create_examples(
self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")
def get_labels(self):
"""See base class."""
"""
if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
pickle.dump(self.labels, fd)
"""
return ["-1", "0", "1"]
def _create_examples(self, lines, set_type):
"""Creates examples for the training and dev sets."""
examples = []
for (i, line) in enumerate(lines):
if i == 0:
continue
guid = "%s-%s" % (set_type, i)
#debug (by xmxoxo)
#print("read line: No.%d" % i)
text_a = tokenization.convert_to_unicode(line[1])
if set_type == "test":
label = "0"
else:
label = tokenization.convert_to_unicode(line[0])
examples.append(
InputExample(guid=guid, text_a=text_a, label=label))
return examples
#-----------------------------------------
然後新增一個方法:
processors = {
"cola": ColaProcessor,
"mnli": MnliProcessor,
"mrpc": MrpcProcessor,
"xnli": XnliProcessor,
"setiment": SetimentProcessor, #2019/3/27 add by Echo
}
特別說明,這裡有一點要注意,在後期部署的時候,需要一個label2id的字典,所以要在訓練的時候就儲存起來,在convert_single_example
這個方法裡增加一段:
#--- save label2id.pkl ---
#在這裡輸出label2id.pkl , add by xmxoxo 2019/2/27
output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
if not os.path.exists(output_label2id_file):
with open(output_label2id_file,'wb') as w:
pickle.dump(label_map,w)
#--- Add end ---
這樣訓練後就會生成這個檔案了。
使用以下命令訓練模型,目錄引數請根據各自的情況修改:
cd /mnt/sda1/transdat/bert-demo/bert/
export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export GLUE_DIR=/mnt/sda1/transdat/bert-demo/bert/data
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0
sudo python run_mobile.py \
--task_name=setiment \
--do_train=true \
--do_eval=true \
--data_dir=$GLUE_DIR/$EXP_NAME \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
--max_seq_length=128 \
--train_batch_size=32 \
--learning_rate=2e-5 \
--num_train_epochs=5.0 \
--output_dir=$TRAINED_CLASSIFIER/$EXP_NAME
由於資料比較小,訓練是比較快的,訓練完成後,可以在輸出目錄得到模型檔案,這裡的模型檔案格式是.ckpt的。 訓練結果:
eval_accuracy = 0.861643
eval_f1 = 0.9536328
eval_loss = 0.56324786
eval_precision = 0.9491279
eval_recall = 0.9581805
global_step = 759
loss = 0.5615213
可以使用以下語句來進行預測:
sudo python run_mobile.py \
--task_name=setiment \
--do_predict=true \
--data_dir=$GLUE_DIR/$EXP_NAME \
--vocab_file=$BERT_BASE_DIR/vocab.txt \
--bert_config_file=$BERT_BASE_DIR/bert_config.json \
--init_checkpoint=$TRAINED_CLASSIFIER/$EXP_NAME \
--max_seq_length=128 \
--output_dir=$TRAINED_CLASSIFIER/$EXP_NAME
模型格式轉化
到這裡我們已經訓練得到了模型,但這個模型是.ckpt的檔案格式,檔案比較大,並且有三個檔案:
-rw-r--r-- 1 root root 1227239468 Apr 15 17:46 model.ckpt-759.data-00000-of-00001
-rw-r--r-- 1 root root 22717 Apr 15 17:46 model.ckpt-759.index
-rw-r--r-- 1 root root 3948381 Apr 15 17:46 model.ckpt-759.meta
可以看到,模板檔案非常大,大約有1.17G。
後面使用的模型服務端,使用的是.pb格式的模型檔案,所以需要把生成的ckpt格式模型檔案轉換成.pb格式的模型檔案。
我這裡提供了一個轉換工具:freeze_graph.py
,使用如下:
usage: freeze_graph.py [-h] -bert_model_dir BERT_MODEL_DIR -model_dir
MODEL_DIR [-model_pb_dir MODEL_PB_DIR]
[-max_seq_len MAX_SEQ_LEN] [-num_labels NUM_LABELS]
[-verbose]
這裡要注意的引數是:
model_dir
就是訓練好的.ckpt檔案所在的目錄max_seq_len
要與原來一致;num_labels
是分類標籤的個數,本例中是3個
python freeze_graph.py \
-bert_model_dir $BERT_BASE_DIR \
-model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-max_seq_len 128 \
-num_labels 3
執行成功後可以看到在model_dir
目錄會生成一個classification_model.pb
檔案。
轉為.pb格式的模型檔案,同時也可以縮小模型檔案的大小,可以看到轉化後的模型檔案大約是390M。
-rw-rw-r-- 1 hexi hexi 409326375 Apr 15 17:58 classification_model.pb
服務端部署與啟動
現在可以安裝服務端了,使用的是 bert-base, 來自於專案BERT-BiLSTM-CRF-NER
, 服務端只是該專案中的一個部分。
專案地址:https://github.com/macanv/BERT-BiLSTM-CRF-NER ,感謝Macanv同學提供這麼好的專案。
這裡要說明一下,我們經常會看到bert-as-service 這個專案的介紹,它只能載入BERT的預訓練模型,輸出文字向量化的結果。 而如果要載入fine-turing後的模型,就要用到 bert-base 了,詳請請見: 基於BERT預訓練的中文命名實體識別TensorFlow實現
下載程式碼並安裝 :
pip install bert-base==0.0.7 -i https://pypi.python.org/simple
或者
git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER
cd BERT-BiLSTM-CRF-NER/
python3 setup.py install
使用 bert-base 有三種執行模式,分別支援三種模型,使用引數-mode
來指定:
- NER 序列標註型別,比如命名實體識別;
- CLASS 分類模型,就是本文中使用的模型
- BERT 這個就是跟bert-as-service 一樣的模式了
之所以要分成不同的執行模式,是因為不同模型對輸入內容的預處理是不同的,命名實體識別NER是要進行序列標註; 而分類模型只要返回label就可以了。
安裝完後執行服務,同時指定監聽 HTTP 8091埠,並使用GPU 1來跑;
cd /mnt/sda1/transdat/bert-demo/bert/bert_svr
export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0
bert-base-serving-start \
-model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-bert_model_dir $BERT_BASE_DIR \
-model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-mode CLASS \
-max_seq_len 128 \
-http_port 8091 \
-port 5575 \
-port_out 5576 \
-device_map 1
注意:port 和 port_out 這兩個引數是API呼叫的埠號, 預設是5555和5556,如果你準備部署多個模型服務例項,那一定要指定自己的埠號,避免衝突。 我這裡是改為: 5575 和 5576
如果報錯沒執行起來,可能是有些模組沒裝上,都是 bert_base/server/http.py裡引用的,裝上就好了:
sudo pip install flask
sudo pip install flask_compress
sudo pip install flask_cors
sudo pip install flask_json
我這裡的配置是2個GTX 1080 Ti,這個時候雙卡的優勢終於發揮出來了,GPU 1用於預測,GPU 0還可以繼續訓練模型。
執行服務後會自動生成很多臨時的目錄和檔案,為了方便管理與啟動,可建立一個工作目錄,並把啟動命令寫成一個shell指令碼。
這裡建立的是mobile_svr\bertsvr.sh
,這樣可以比較方便地設定伺服器啟動時自動啟動服務,另外增加了每次啟動時自動清除臨時檔案
程式碼如下:
#!/bin/bash
#chkconfig: 2345 80 90
#description: 啟動BERT分類模型
echo '正在啟動 BERT mobile svr...'
cd /mnt/sda1/transdat/bert-demo/bert/mobile_svr
sudo rm -rf tmp*
export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0
bert-base-serving-start \
-model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-bert_model_dir $BERT_BASE_DIR \
-model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
-mode CLASS \
-max_seq_len 128 \
-http_port 8091 \
-port 5575 \
-port_out 5576 \
-device_map 1
補充說明一下記憶體的使用情況: BERT在訓練時需要載入完整的模型資料,要用的記憶體是比較多的,差不多要10G,我這裡用的是GTX 1080 Ti 11G。 但在訓練完後,按上面的方式部署載入pb模型檔案時,就不需要那麼大了,上面也可以看到pb模型檔案就是390M。 其實只要你使用的是BERT base 預訓練模型,最終的得到的pb檔案大小都是差不多的。
還有同學問到能不能用CPU來部署,我這裡沒嘗試過,但我想肯定是可以的,只是在計算速度上跟GPU會有差別。
我這裡使用GPU 1來實時預測,同時載入了2個BERT模型,截圖如下:
埠測試
模型服務端部署完成了,可以使用curl命令來測試一下它的執行情況。
curl -X POST http://192.168.15.111:8091/encode \
-H 'content-type: application/json' \
-d '{"id": 111,"texts": ["總的來說,這款手機價效比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'
執行結果:
> -H 'content-type: application/json' \
> -d '{"id": 111,"texts": ["總的來說,這款手機價效比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'
{"id":111,"result":[{"pred_label":["1","-1"],"score":[0.9974544644355774,0.9961422085762024]}],"status":200}
可以看到對應的兩個評論,預測結果一個是1,另一個是-1,計算的速度還是非常很快的。 通過這種方式來呼叫還是不太方便,知道了這個通訊方式,我們可以用flask編寫一個API服務, 為所有的應用統一提供服務。
API服務編寫與部署
為了方便客戶端的呼叫,同時也為了可以對多個語句進行預測,我們用flask編寫一個API服務端,使用更簡潔的方式來與客戶端(應用)來通訊。
整個API服務端放在獨立目錄/mobile_apisvr/
目錄下。
用flask建立服務端並呼叫主方法,命令列引數如下:
def main_cli ():
pass
parser = argparse.ArgumentParser(description='API demo server')
parser.add_argument('-ip', type=str, default="0.0.0.0",
help='chinese google bert model serving')
parser.add_argument('-port', type=int, default=8910,
help='listen port,default:8910')
args = parser.parse_args()
flask_server(args)
主方法裡建立APP物件:
app.run(
host = args.ip, #'0.0.0.0',
port = args.port, #8910,
debug = True
)
這裡的介面簡單規劃為/api/v0.1/query
, 使用POST方法,引數名為'text',使用JSON返回結果;
路由配置:
@app.route('/api/v0.1/query', methods=['POST'])
API服務端的核心方法,是與BERT-Serving進行通訊,需要建立一個客戶端BertClient:
#對句子進行預測識別
def class_pred(list_text):
#文字拆分成句子
#list_text = cut_sent(text)
print("total setance: %d" % (len(list_text)) )
with BertClient(ip='192.168.15.111', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False,timeout=10000 , mode='CLASS') as bc:
start_t = time.perf_counter()
rst = bc.encode(list_text)
print('result:', rst)
print('time used:{}'.format(time.perf_counter() - start_t))
#返回結構為:
# rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
#抽取出標註結果
pred_label = rst[0]["pred_label"]
result_txt = [ [pred_label[i],list_text[i] ] for i in range(len(pred_label))]
return result_txt
注意:這裡的IP,埠要與服務端的對應。
執行API 服務端:
python api_service.py
在程式碼中的debug設定為True,這樣只要更新檔案,服務就會自動重新啟動,比較方便除錯。 執行截圖如下:
到這一步也可以使用curl或者其它工具進行測試,也可以等完成網頁客戶端後一併除錯。 我這裡使用chrome外掛 API-debug來進行測試,如下圖:
客戶端(網頁端)
這裡使用一個HTML頁面來模擬客戶端,在實際專案中可能是具體的應用。 為了方便演示就把網頁模板與API服務端合併在一起了,在網頁端使用AJAX來與API服務端通訊。
建立模板目錄templates
,使用模板來載入一個HTML,模板檔名為index.html
。
在HTML頁面裡使用AJAX來呼叫介面,由於是在同一個伺服器,同一個埠,地址直接用/api/v0.1/query
就可以了,
在實際專案中,客戶應用端與API是分開的,則需要指定介面URL地址,同時還要注意資料安全性。
程式碼如下:
function UrlPOST(txt,myfun){
if (txt=="")
{
return "error parm";
}
var httpurl = "/api/v0.1/query";
$.ajax({
type: "POST",
data: "text="+txt,
url: httpurl,
//async:false,
success: function(data)
{
myfun(data);
}
});
}
啟動API服務端後,可以使用IP+埠
來訪問了,這裡的地址是http://192.168.15.111:8910/
執行介面截圖如下:
可以看到請求的用時時間為37ms,速度還是很快的,當然這個速度跟硬體配置有關。
參考資料:
歡迎批評指正,聯絡郵箱([email protected]