1. 程式人生 > >BERT模型從訓練到部署

BERT模型從訓練到部署

BERT模型從訓練到部署全流程

標籤: BERT 訓練 部署

緣起

在群裡看到許多朋友在使用BERT模型,網上多數文章只提到了模型的訓練方法,後面的生產部署及呼叫並沒有說明。 這段時間使用BERT模型完成了從資料準備到生產部署的全流程,在這裡整理出來,方便大家參考。

在下面我將以一個“手機評論的情感分類”為例子,簡要說明從訓練到部署的全部流程。最終完成後可以使用一個網頁進行互動,實時地對輸入的評論語句進行分類判斷。

基本架構

基本架構為:

graph LR
A(BERT模型服務端) --> B(API服務端)
B-->A
B --> C(應用端)
C-->B
+-------------------+
|   應用端(HTML)     | 
+-------------------+
         ^^
         ||
         VV
+-------------------+
|     API服務端      | 
+-------------------+
         ^^
         ||
         VV
+-------------------+
|  BERT模型服務端    | 
+-------------------+

架構說明: BERT模型服務端     載入模型,進行實時預測的服務;     使用的是 BERT-BiLSTM-CRF-NER 

API服務端      呼叫實時預測服務,為應用提供API介面的服務,用flask編寫; 

應用端     最終的應用端;     我這裡使用一個HTML網頁來實現;

本專案完整原始碼地址:BERT從訓練到部署git原始碼 專案部落格地址: BERT從訓練到部署

附件: 本例中訓練完成的模型檔案.ckpt格式及.pb格式檔案,由於比較大,已放到網盤提供下載:

連結:https://pan.baidu.com/s/1DgVjRK7zicbTlAAkFp7nWw 
提取碼:8iaw 

如果你想跳過前面模型的訓練過程,可以直接使用訓練好的模型,來完成後面的部署。

關鍵節點

主要包括以下關鍵節點:

  • 資料準備
  • 模型訓練
  • 模型格式轉化
  • 服務端部署與啟動
  • API服務編寫與部署
  • 客戶端(網頁端的編寫與部署)

資料準備

這裡用的資料是手機的評論,資料比較簡單,三個分類: -1,0,1 表示負面,中性與正面情感 資料格式如下:

1    手機很好,漂亮時尚,贈品一般
1    手機很好。包裝也很完美,贈品也是收到貨後馬上就發貨了
1    第一次在第三方買的手機 開始很擔心 不過查一下是正品 很滿意
1    很不錯 續航好 系統流暢
1    不知道真假,相信店家吧
1    快遞挺快的,榮耀10手感還是不錯的,玩了會王者還不錯,就是前後玻璃,
1    流很快,手機到手感覺很酷,白色適合女士,很驚豔!常好,執行速度快,流暢!
1    用了一天才來評價,都還可以,很滿意
1    幻影藍很好看啊,炫彩系列時尚時尚最時尚,速度快,配送執行?做活動優惠買的,開心?
1    快遞速度快,很贊!軟體更新到最新版。安裝上軟膠保護套拿手上不容易滑落。
0    手機出廠貼膜好薄啊,感覺像塑料膜。其他不能發表
0    用了一段時間,除了手機續航其它還不錯。
0    做工一般
1    挺好的,贊一個,手機很好,很喜歡
0    手機還行,但是手機剛開箱時螢幕和背面有很多指紋痕跡,手機殼跟**在地上磨過似的,好幾條印子。要不是看在能把這些痕跡擦掉,和閒退貨麻煩,就給退了。就不能規規矩矩做生意麼。還有送的都是什麼吊東西,運動手環垃圾一比,貼在手機後面的固定手環還**是塑料的渡了一層銀色,耳機也和圖片描述不符,碎屏險已經註冊,不知道怎麼樣。講真的,要不就別送或者少送,要不,就規規矩矩的,不然到最後還讓人覺得不舒服。其他沒什麼。
-1    手機整體還可以,拍照也很清楚,也很流暢支援華為。給一星是因為有缺陷,送的耳機是壞的!評論區好評太多,需要一些差評來提醒下,以後更加註意細節,提升質量。
0    前天剛買的,  看著還行, 指紋解鎖反應不錯。
1    高階大氣上檔次。
-1    各位小主,注意啦,耳機是沒有的,需要單獨買
0    外觀不錯,感覺很耗電啊,在使用段時間評價
1    手機非常好,很好用
-1    沒有發票,圖片與實物不一致
1    習慣在京東採購物品,方便快捷,及時開發票進行報銷,配送員服務也很周到!就是手機收到時沒有電,感覺不大正常
1    高階大氣上檔次啊!看電影玩遊戲估計很爽!螢幕夠大!

資料總共8097條,按6:2:2的比例拆分成train.tsv,test.tsv ,dev.tsv三個資料檔案

模型訓練

訓練模型就直接使用BERT的分類方法,把原來的run_classifier.py 複製出來並修改為 run_mobile.py。關於訓練的程式碼網上很多,就不展開說明了,主要有以下方法:

#-----------------------------------------
#手機評論情感分類資料處理 2019/3/12 
#labels: -1負面 0中性 1正面
class SetimentProcessor(DataProcessor):
  def get_train_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

  def get_dev_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

  def get_test_examples(self, data_dir):
    """See base class."""
    return self._create_examples(
        self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

  def get_labels(self):
    """See base class."""

    """
    if not os.path.exists(os.path.join(FLAGS.output_dir, 'label_list.pkl')):
        with codecs.open(os.path.join(FLAGS.output_dir, 'label_list.pkl'), 'wb') as fd:
            pickle.dump(self.labels, fd)
    """
    return ["-1", "0", "1"]

  def _create_examples(self, lines, set_type):
    """Creates examples for the training and dev sets."""
    examples = []
    for (i, line) in enumerate(lines):
      if i == 0: 
        continue
      guid = "%s-%s" % (set_type, i)

      #debug (by xmxoxo)
      #print("read line: No.%d" % i)

      text_a = tokenization.convert_to_unicode(line[1])
      if set_type == "test":
        label = "0"
      else:
        label = tokenization.convert_to_unicode(line[0])
      examples.append(
          InputExample(guid=guid, text_a=text_a, label=label))
    return examples
#-----------------------------------------

然後新增一個方法:

  processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "setiment": SetimentProcessor, #2019/3/27 add by Echo
  }

特別說明,這裡有一點要注意,在後期部署的時候,需要一個label2id的字典,所以要在訓練的時候就儲存起來,在convert_single_example這個方法裡增加一段:

  #--- save label2id.pkl ---
  #在這裡輸出label2id.pkl , add by xmxoxo 2019/2/27
  output_label2id_file = os.path.join(FLAGS.output_dir, "label2id.pkl")
  if not os.path.exists(output_label2id_file):
    with open(output_label2id_file,'wb') as w:
      pickle.dump(label_map,w)

  #--- Add end ---

這樣訓練後就會生成這個檔案了。

使用以下命令訓練模型,目錄引數請根據各自的情況修改:

cd /mnt/sda1/transdat/bert-demo/bert/
export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export GLUE_DIR=/mnt/sda1/transdat/bert-demo/bert/data
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

sudo python run_mobile.py \
  --task_name=setiment \
  --do_train=true \
  --do_eval=true \
  --data_dir=$GLUE_DIR/$EXP_NAME \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=5.0 \
  --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME

由於資料比較小,訓練是比較快的,訓練完成後,可以在輸出目錄得到模型檔案,這裡的模型檔案格式是.ckpt的。 訓練結果:

eval_accuracy = 0.861643
eval_f1 = 0.9536328
eval_loss = 0.56324786
eval_precision = 0.9491279
eval_recall = 0.9581805
global_step = 759
loss = 0.5615213

可以使用以下語句來進行預測:

sudo python run_mobile.py \
  --task_name=setiment \
  --do_predict=true \
  --data_dir=$GLUE_DIR/$EXP_NAME \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER/$EXP_NAME \
  --max_seq_length=128 \
  --output_dir=$TRAINED_CLASSIFIER/$EXP_NAME

模型格式轉化

到這裡我們已經訓練得到了模型,但這個模型是.ckpt的檔案格式,檔案比較大,並且有三個檔案:

-rw-r--r-- 1 root root 1227239468 Apr 15 17:46 model.ckpt-759.data-00000-of-00001
-rw-r--r-- 1 root root      22717 Apr 15 17:46 model.ckpt-759.index
-rw-r--r-- 1 root root    3948381 Apr 15 17:46 model.ckpt-759.meta

可以看到,模板檔案非常大,大約有1.17G。 後面使用的模型服務端,使用的是.pb格式的模型檔案,所以需要把生成的ckpt格式模型檔案轉換成.pb格式的模型檔案。 我這裡提供了一個轉換工具:freeze_graph.py,使用如下:

usage: freeze_graph.py [-h] -bert_model_dir BERT_MODEL_DIR -model_dir
                       MODEL_DIR [-model_pb_dir MODEL_PB_DIR]
                       [-max_seq_len MAX_SEQ_LEN] [-num_labels NUM_LABELS]
                       [-verbose]

這裡要注意的引數是:

  • model_dir 就是訓練好的.ckpt檔案所在的目錄
  • max_seq_len 要與原來一致;
  • num_labels 是分類標籤的個數,本例中是3個
python freeze_graph.py \
    -bert_model_dir $BERT_BASE_DIR \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -max_seq_len 128 \
    -num_labels 3

執行成功後可以看到在model_dir目錄會生成一個classification_model.pb 檔案。 轉為.pb格式的模型檔案,同時也可以縮小模型檔案的大小,可以看到轉化後的模型檔案大約是390M。

-rw-rw-r-- 1 hexi hexi 409326375 Apr 15 17:58 classification_model.pb

服務端部署與啟動

現在可以安裝服務端了,使用的是 bert-base, 來自於專案BERT-BiLSTM-CRF-NER, 服務端只是該專案中的一個部分。 專案地址:https://github.com/macanv/BERT-BiLSTM-CRF-NER ,感謝Macanv同學提供這麼好的專案。

這裡要說明一下,我們經常會看到bert-as-service 這個專案的介紹,它只能載入BERT的預訓練模型,輸出文字向量化的結果。 而如果要載入fine-turing後的模型,就要用到 bert-base 了,詳請請見: 基於BERT預訓練的中文命名實體識別TensorFlow實現

下載程式碼並安裝 :

pip install bert-base==0.0.7 -i https://pypi.python.org/simple

或者 

git clone https://github.com/macanv/BERT-BiLSTM-CRF-NER
cd BERT-BiLSTM-CRF-NER/
python3 setup.py install

使用 bert-base 有三種執行模式,分別支援三種模型,使用引數-mode 來指定:

  • NER      序列標註型別,比如命名實體識別;
  • CLASS    分類模型,就是本文中使用的模型
  • BERT     這個就是跟bert-as-service 一樣的模式了

之所以要分成不同的執行模式,是因為不同模型對輸入內容的預處理是不同的,命名實體識別NER是要進行序列標註; 而分類模型只要返回label就可以了。

安裝完後執行服務,同時指定監聽 HTTP 8091埠,並使用GPU 1來跑;

cd /mnt/sda1/transdat/bert-demo/bert/bert_svr

export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

bert-base-serving-start \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -bert_model_dir $BERT_BASE_DIR \
    -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1 

注意:port 和 port_out 這兩個引數是API呼叫的埠號, 預設是5555和5556,如果你準備部署多個模型服務例項,那一定要指定自己的埠號,避免衝突。 我這裡是改為: 5575 和 5576

如果報錯沒執行起來,可能是有些模組沒裝上,都是 bert_base/server/http.py裡引用的,裝上就好了:

sudo pip install flask 
sudo pip install flask_compress
sudo pip install flask_cors
sudo pip install flask_json

我這裡的配置是2個GTX 1080 Ti,這個時候雙卡的優勢終於發揮出來了,GPU 1用於預測,GPU 0還可以繼續訓練模型。

執行服務後會自動生成很多臨時的目錄和檔案,為了方便管理與啟動,可建立一個工作目錄,並把啟動命令寫成一個shell指令碼。 這裡建立的是mobile_svr\bertsvr.sh ,這樣可以比較方便地設定伺服器啟動時自動啟動服務,另外增加了每次啟動時自動清除臨時檔案

程式碼如下:

#!/bin/bash
#chkconfig: 2345 80 90
#description: 啟動BERT分類模型 

echo '正在啟動 BERT mobile svr...'
cd /mnt/sda1/transdat/bert-demo/bert/mobile_svr
sudo rm -rf tmp*

export BERT_BASE_DIR=/mnt/sda1/transdat/bert-demo/bert/chinese_L-12_H-768_A-12
export TRAINED_CLASSIFIER=/mnt/sda1/transdat/bert-demo/bert/output
export EXP_NAME=mobile_0

bert-base-serving-start \
    -model_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -bert_model_dir $BERT_BASE_DIR \
    -model_pb_dir $TRAINED_CLASSIFIER/$EXP_NAME \
    -mode CLASS \
    -max_seq_len 128 \
    -http_port 8091 \
    -port 5575 \
    -port_out 5576 \
    -device_map 1 

補充說明一下記憶體的使用情況: BERT在訓練時需要載入完整的模型資料,要用的記憶體是比較多的,差不多要10G,我這裡用的是GTX 1080 Ti 11G。 但在訓練完後,按上面的方式部署載入pb模型檔案時,就不需要那麼大了,上面也可以看到pb模型檔案就是390M。 其實只要你使用的是BERT base 預訓練模型,最終的得到的pb檔案大小都是差不多的。

還有同學問到能不能用CPU來部署,我這裡沒嘗試過,但我想肯定是可以的,只是在計算速度上跟GPU會有差別。

我這裡使用GPU 1來實時預測,同時載入了2個BERT模型,截圖如下:

GPU截圖

埠測試

模型服務端部署完成了,可以使用curl命令來測試一下它的執行情況。

curl -X POST http://192.168.15.111:8091/encode \
  -H 'content-type: application/json' \
  -d '{"id": 111,"texts": ["總的來說,這款手機價效比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'

執行結果:

>   -H 'content-type: application/json' \
>   -d '{"id": 111,"texts": ["總的來說,這款手機價效比是特別高的。","槽糕的售後服務!!!店大欺客"], "is_tokenized": false}'
{"id":111,"result":[{"pred_label":["1","-1"],"score":[0.9974544644355774,0.9961422085762024]}],"status":200}

可以看到對應的兩個評論,預測結果一個是1,另一個是-1,計算的速度還是非常很快的。 通過這種方式來呼叫還是不太方便,知道了這個通訊方式,我們可以用flask編寫一個API服務, 為所有的應用統一提供服務。

API服務編寫與部署

為了方便客戶端的呼叫,同時也為了可以對多個語句進行預測,我們用flask編寫一個API服務端,使用更簡潔的方式來與客戶端(應用)來通訊。 整個API服務端放在獨立目錄/mobile_apisvr/目錄下。

用flask建立服務端並呼叫主方法,命令列引數如下:

def main_cli ():
    pass
    parser = argparse.ArgumentParser(description='API demo server')
    parser.add_argument('-ip', type=str, default="0.0.0.0",
                        help='chinese google bert model serving')
    parser.add_argument('-port', type=int, default=8910,
                        help='listen port,default:8910')

    args = parser.parse_args()

    flask_server(args)

主方法裡建立APP物件:


    app.run(
        host = args.ip,     #'0.0.0.0',
        port = args.port,   #8910,  
        debug = True 
    )

這裡的介面簡單規劃為/api/v0.1/query, 使用POST方法,引數名為'text',使用JSON返回結果; 路由配置:

@app.route('/api/v0.1/query', methods=['POST'])

API服務端的核心方法,是與BERT-Serving進行通訊,需要建立一個客戶端BertClient:

#對句子進行預測識別
def class_pred(list_text):
    #文字拆分成句子
    #list_text = cut_sent(text)
    print("total setance: %d" % (len(list_text)) )
    with BertClient(ip='192.168.15.111', port=5575, port_out=5576, show_server_config=False, check_version=False, check_length=False,timeout=10000 ,  mode='CLASS') as bc:
        start_t = time.perf_counter()
        rst = bc.encode(list_text)
        print('result:', rst)
        print('time used:{}'.format(time.perf_counter() - start_t))
    #返回結構為:
    # rst: [{'pred_label': ['0', '1', '0'], 'score': [0.9983683228492737, 0.9988993406295776, 0.9997349381446838]}]
    #抽取出標註結果
    pred_label = rst[0]["pred_label"]
    result_txt = [ [pred_label[i],list_text[i] ] for i in range(len(pred_label))]
    return result_txt

注意:這裡的IP,埠要與服務端的對應。

執行API 服務端:

python api_service.py

在程式碼中的debug設定為True,這樣只要更新檔案,服務就會自動重新啟動,比較方便除錯。 執行截圖如下:

API服務端

到這一步也可以使用curl或者其它工具進行測試,也可以等完成網頁客戶端後一併除錯。 我這裡使用chrome外掛 API-debug來進行測試,如下圖:

API測試

客戶端(網頁端)

這裡使用一個HTML頁面來模擬客戶端,在實際專案中可能是具體的應用。 為了方便演示就把網頁模板與API服務端合併在一起了,在網頁端使用AJAX來與API服務端通訊。

建立模板目錄templates,使用模板來載入一個HTML,模板檔名為index.html。 在HTML頁面裡使用AJAX來呼叫介面,由於是在同一個伺服器,同一個埠,地址直接用/api/v0.1/query就可以了, 在實際專案中,客戶應用端與API是分開的,則需要指定介面URL地址,同時還要注意資料安全性。 程式碼如下:

function UrlPOST(txt,myfun){
    if (txt=="")
    {
        return "error parm"; 
    }
    var httpurl = "/api/v0.1/query"; 
    $.ajax({
            type: "POST",
            data: "text="+txt,
            url: httpurl,
            //async:false,
            success: function(data)
            {   
                myfun(data);
            }
    });
}

啟動API服務端後,可以使用IP+埠來訪問了,這裡的地址是http://192.168.15.111:8910/

執行介面截圖如下:

執行介面截圖

可以看到請求的用時時間為37ms,速度還是很快的,當然這個速度跟硬體配置有關。

參考資料:

歡迎批評指正,聯絡郵箱([email protected]