1. 程式人生 > >NLP(十七)利用tensorflow-serving部署kashgari模型

NLP(十七)利用tensorflow-serving部署kashgari模型

  在文章NLP(十五)讓模型來告訴你文字中的時間中,我們已經學會了如何利用kashgari模組來完成序列標註模型的訓練與預測,在本文中,我們將會了解如何tensorflow-serving來部署模型。
  在kashgari的官方文件中,已經有如何利用tensorflow-serving來部署模型的說明了,網址為:https://kashgari.bmio.net/advance-use/tensorflow-serving/ 。
  下面,本文將介紹tensorflow-serving以及如何利用tensorflow-serving來部署kashgari的模型。

tensorflow-serving

   TensorFlow Serving 是一個用於機器學習模型 serving 的高效能開源庫。它可以將訓練好的機器學習模型部署到線上,使用 gRPC 作為介面接受外部呼叫。更加讓人眼前一亮的是,它支援模型熱更新與自動模型版本管理。這意味著一旦部署 TensorFlow Serving 後,你再也不需要為線上服務操心,只需要關心你的線下模型訓練。
  TensorFlow Serving可以方便我們部署TensorFlow模型,本文將使用TensorFlow Serving的Docker映象來使用TensorFlow Serving,安裝的命令如下:

docker pull tensorflow/serving

工程實踐

  本專案將演示如何利用tensorflow/serving來部署kashgari中的模型,專案結構如下:

  本專案的data來自之前筆者標註的時間資料集,即標註出文本中的時間,採用BIO標註系統。chinese_wwm_ext資料夾為哈工大的預訓練模型檔案。
  model_train.py為模型訓練的程式碼,主要功能是完成時間序列標註模型的訓練,完整的程式碼如下:

# -*- coding: utf-8 -*-
# time: 2019-09-12
# place: Huangcun Beijing

import kashgari
from kashgari import utils
from kashgari.corpus import DataReader
from kashgari.embeddings import BERTEmbedding
from kashgari.tasks.labeling import BiLSTM_CRF_Model

# 模型訓練

train_x, train_y = DataReader().read_conll_format_file('./data/time.train')
valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev')
test_x, test_y = DataReader().read_conll_format_file('./data/time.test')

bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12',
                               task=kashgari.LABELING,
                               sequence_length=128)

model = BiLSTM_CRF_Model(bert_embedding)

model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1)

# Save model
utils.convert_to_saved_model(model,
                             model_path='saved_model/time_entity',
                             version=1)

  執行該程式碼,模型訓練完後會生成saved_model資料夾,裡面含有模型訓練好後的檔案,方便我們利用tensorflow/serving進行部署。接著我們利用tensorflow/serving來完成模型的部署,命令如下:

docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving

其中需要注意該模型所在的路徑,路徑需要寫完整路徑,以及模型的名稱(MODEL_NAME),這在訓練程式碼(train.py)中已經給出(saved_model/time_entity)。

  接著我們使用tornado來搭建HTTP服務,幫助我們方便地進行模型預測,runServer.py的完整程式碼如下:

# -*- coding: utf-8 -*-
import requests
from kashgari import utils
import numpy as np
from model_predict import get_predict

import json
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import traceback

# tornado高併發
import tornado.web
import tornado.gen
import tornado.concurrent
from concurrent.futures import ThreadPoolExecutor

# 定義埠為12333
define("port", default=16016, help="run on the given port", type=int)

# 模型預測
class ModelPredictHandler(tornado.web.RequestHandler):
    executor = ThreadPoolExecutor(max_workers=5)

    # get 函式
    @tornado.gen.coroutine
    def get(self):
        origin_text = self.get_argument('text')
        result = yield self.function(origin_text)
        self.write(json.dumps(result, ensure_ascii=False))

    @tornado.concurrent.run_on_executor
    def function(self, text):
        try:
            text = text.replace(' ', '')
            x = [_ for _ in text]

            # Pre-processor data
            processor = utils.load_processor(model_path='saved_model/time_entity/1')
            tensor = processor.process_x_dataset([x])

            # only for bert Embedding
            tensor = [{
                "Input-Token:0": i.tolist(),
                "Input-Segment:0": np.zeros(i.shape).tolist()
            } for i in tensor]

            # predict
            r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor})
            preds = r.json()['predictions']

            # Convert result back to labels
            labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1))

            entities = get_predict('TIME', text, labels[0])

            return entities

        except Exception:
            self.write(traceback.format_exc().replace('\n', '<br>'))


# get請求
class HelloHandler(tornado.web.RequestHandler):
    def get(self):
        self.write('Hello from lmj from Daxing Beijing!')


# 主函式
def main():
    # 開啟tornado服務
    tornado.options.parse_command_line()
    # 定義app
    app = tornado.web.Application(
            handlers=[(r'/model_predict', ModelPredictHandler),
                      (r'/hello', HelloHandler),
                      ], #網頁路徑控制
          )
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()

main()

  我們定義了tornado封裝HTTP服務來進行模型預測,執行該指令碼,啟動模型預測的HTTP服務。接著我們再使用Python指令碼才測試下模型的預測效果以及預測時間,預測的程式碼指令碼的完整程式碼如下:

import time
import json
import requests

t1 = time.time()
texts = ['據《新聞聯播》報道,9月9日至11日,中央紀委書記趙樂際到河北調研。',
         '記者從國家發展改革委、商務部相關方面獲悉,日前美方已決定對擬於10月1日實施的中國輸美商品加徵關稅措施做出調整,中方支援相關企業從即日起按照市場化原則和WTO規則,自美採購一定數量大豆、豬肉等農產品,國務院關稅稅則委員會將對上述採購予以加徵關稅排除。',
         '據印度Zee新聞網站12日報道,亞洲新聞國際通訊社援引印度軍方訊息人士的話說,9月11日的對峙事件發生在靠近班公錯北岸的實際控制線一帶。',
         '儋州市決定,從9月開始,對城市低保、農村低保、特困供養人員、優撫物件、領取失業保險金人員、建檔立卡未脫貧人口等低收入群體共3萬多人,發放豬肉價格補貼,每人每月發放不低於100元補貼,以後發放標準,將根據豬肉價波動情況進行動態調整。',
         '9月11日,華為心聲社群釋出美國經濟學家托馬斯.弗裡德曼在《紐約時報》上的專欄內容,弗裡德曼透露,在與華為創始人任正非最近一次採訪中,任正非表示華為願意與美國司法部展開話題不設限的討論。',
         '造血幹細胞移植治療白血病技術已日益成熟,然而,通過該方法同時治癒艾滋病目前還是一道全球尚在攻克的難題。',
         '英國航空事故調查局(AAIB)近日披露,今年2月6日一趟由德國法蘭克福飛往墨西哥坎昆的航班上,因飛行員打翻咖啡使操作面板冒煙,導致飛機折返迫降愛爾蘭。',
         '當地時間週四(9月12日),印度尼西亞財政部長英卓華(Sri Mulyani Indrawati)明確表示:特朗普的推特是風險之一。',
         '華中科技大學9月12日通過其官方網站釋出通報稱,9月2日,我校一碩士研究生不幸墜樓身亡。',
         '微博使用者@ooooviki 9月12日下午公佈發生在自己身上的驚悚遭遇:一個自稱網警、名叫鄭洋的人利用職務之便,查到她的完備的個人資訊,包括但不限於身份證號、家庭地址、電話號碼、戶籍變動情況等,要求她做他女朋友。',
         '今天,貴陽取消了汽車限購,成為目前全國實行限購政策的9個省市中,首個取消限購的城市。',
         '據悉,與全球同步,中國區此次將於9月13日於iPhone官方渠道和京東正式開啟預售,京東成Apple中國區唯一官方授權預售渠道。',
         '根據央行公佈的資料,截至2019年6月末,存款類金融機構住戶部門短期消費貸款規模為9.11萬億元,2019年上半年該項淨增3293.19億元,上半年增量看起來並不樂觀。',
         '9月11日,一段拍攝浙江萬里學院學生食堂的視訊走紅網路,視訊顯示該學校食堂不僅在用餐區域設定了可以看電影、比賽的大螢幕,還推出了“一人食”餐位。',
         '當日,在北京舉行的2019年國際籃聯籃球世界盃半決賽中,西班牙隊對陣澳大利亞隊。',
         ]

print(len(texts))

for text in texts:
    url = 'http://localhost:16016/model_predict?text=%s' % text
    req = requests.get(url)
    print(json.loads(req.content))

t2 = time.time()

print(round(t2-t1, 4))

  執行該程式碼,輸出的結果如下:(預測文字中的時間)

一共預測15個句子。
['9月9日至11日']
['日前', '10月1日', '即日']
['12日', '9月11日']
['9月']
['9月11日']
[]
['近日', '今年2月6日']
['當地時間週四(9月12日)']
['9月12日', '9月2日']
['9月12日下午']
['今天', '目前']
['9月13日']
['2019年6月末', '2019年上半年', '上半年']
['9月11日']
['當日', '2019年']
預測耗時: 15.1085s.

模型預測的效果還是不錯的,但平均每句話的預測時間為1秒多,模型預測時間還是稍微偏長,後續筆者將會研究如何縮短模型預測的時間。

總結

   本專案主要是介紹瞭如何利用tensorflow-serving部署kashgari模型,該專案已經上傳至github,地址為:https://github.com/percent4/tensorflow-serving_4_kashgari 。
  至於如何縮短模型預測的時間,筆者還需要再繼續研究,歡迎大家關注