NLP(十七)利用tensorflow-serving部署kashgari模型
在文章NLP(十五)讓模型來告訴你文字中的時間中,我們已經學會了如何利用kashgari模組來完成序列標註模型的訓練與預測,在本文中,我們將會了解如何tensorflow-serving來部署模型。
在kashgari的官方文件中,已經有如何利用tensorflow-serving來部署模型的說明了,網址為:https://kashgari.bmio.net/advance-use/tensorflow-serving/ 。
下面,本文將介紹tensorflow-serving以及如何利用tensorflow-serving來部署kashgari的模型。
tensorflow-serving
TensorFlow Serving 是一個用於機器學習模型 serving 的高效能開源庫。它可以將訓練好的機器學習模型部署到線上,使用 gRPC 作為介面接受外部呼叫。更加讓人眼前一亮的是,它支援模型熱更新與自動模型版本管理。這意味著一旦部署 TensorFlow Serving 後,你再也不需要為線上服務操心,只需要關心你的線下模型訓練。
TensorFlow Serving可以方便我們部署TensorFlow模型,本文將使用TensorFlow Serving的Docker映象來使用TensorFlow Serving,安裝的命令如下:
docker pull tensorflow/serving
工程實踐
本專案將演示如何利用tensorflow/serving來部署kashgari中的模型,專案結構如下:
本專案的data來自之前筆者標註的時間資料集,即標註出文本中的時間,採用BIO標註系統。chinese_wwm_ext資料夾為哈工大的預訓練模型檔案。
model_train.py為模型訓練的程式碼,主要功能是完成時間序列標註模型的訓練,完整的程式碼如下:
# -*- coding: utf-8 -*- # time: 2019-09-12 # place: Huangcun Beijing import kashgari from kashgari import utils from kashgari.corpus import DataReader from kashgari.embeddings import BERTEmbedding from kashgari.tasks.labeling import BiLSTM_CRF_Model # 模型訓練 train_x, train_y = DataReader().read_conll_format_file('./data/time.train') valid_x, valid_y = DataReader().read_conll_format_file('./data/time.dev') test_x, test_y = DataReader().read_conll_format_file('./data/time.test') bert_embedding = BERTEmbedding('chinese_wwm_ext_L-12_H-768_A-12', task=kashgari.LABELING, sequence_length=128) model = BiLSTM_CRF_Model(bert_embedding) model.fit(train_x, train_y, valid_x, valid_y, batch_size=16, epochs=1) # Save model utils.convert_to_saved_model(model, model_path='saved_model/time_entity', version=1)
執行該程式碼,模型訓練完後會生成saved_model資料夾,裡面含有模型訓練好後的檔案,方便我們利用tensorflow/serving進行部署。接著我們利用tensorflow/serving來完成模型的部署,命令如下:
docker run -t --rm -p 8501:8501 -v "/Users/jclian/PycharmProjects/kashgari_tf_serving/saved_model:/models/" -e MODEL_NAME=time_entity tensorflow/serving
其中需要注意該模型所在的路徑,路徑需要寫完整路徑,以及模型的名稱(MODEL_NAME),這在訓練程式碼(train.py)中已經給出(saved_model/time_entity)。
接著我們使用tornado來搭建HTTP服務,幫助我們方便地進行模型預測,runServer.py的完整程式碼如下:
# -*- coding: utf-8 -*-
import requests
from kashgari import utils
import numpy as np
from model_predict import get_predict
import json
import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options
import traceback
# tornado高併發
import tornado.web
import tornado.gen
import tornado.concurrent
from concurrent.futures import ThreadPoolExecutor
# 定義埠為12333
define("port", default=16016, help="run on the given port", type=int)
# 模型預測
class ModelPredictHandler(tornado.web.RequestHandler):
executor = ThreadPoolExecutor(max_workers=5)
# get 函式
@tornado.gen.coroutine
def get(self):
origin_text = self.get_argument('text')
result = yield self.function(origin_text)
self.write(json.dumps(result, ensure_ascii=False))
@tornado.concurrent.run_on_executor
def function(self, text):
try:
text = text.replace(' ', '')
x = [_ for _ in text]
# Pre-processor data
processor = utils.load_processor(model_path='saved_model/time_entity/1')
tensor = processor.process_x_dataset([x])
# only for bert Embedding
tensor = [{
"Input-Token:0": i.tolist(),
"Input-Segment:0": np.zeros(i.shape).tolist()
} for i in tensor]
# predict
r = requests.post("http://localhost:8501/v1/models/time_entity:predict", json={"instances": tensor})
preds = r.json()['predictions']
# Convert result back to labels
labels = processor.reverse_numerize_label_sequences(np.array(preds).argmax(-1))
entities = get_predict('TIME', text, labels[0])
return entities
except Exception:
self.write(traceback.format_exc().replace('\n', '<br>'))
# get請求
class HelloHandler(tornado.web.RequestHandler):
def get(self):
self.write('Hello from lmj from Daxing Beijing!')
# 主函式
def main():
# 開啟tornado服務
tornado.options.parse_command_line()
# 定義app
app = tornado.web.Application(
handlers=[(r'/model_predict', ModelPredictHandler),
(r'/hello', HelloHandler),
], #網頁路徑控制
)
http_server = tornado.httpserver.HTTPServer(app)
http_server.listen(options.port)
tornado.ioloop.IOLoop.instance().start()
main()
我們定義了tornado封裝HTTP服務來進行模型預測,執行該指令碼,啟動模型預測的HTTP服務。接著我們再使用Python指令碼才測試下模型的預測效果以及預測時間,預測的程式碼指令碼的完整程式碼如下:
import time
import json
import requests
t1 = time.time()
texts = ['據《新聞聯播》報道,9月9日至11日,中央紀委書記趙樂際到河北調研。',
'記者從國家發展改革委、商務部相關方面獲悉,日前美方已決定對擬於10月1日實施的中國輸美商品加徵關稅措施做出調整,中方支援相關企業從即日起按照市場化原則和WTO規則,自美採購一定數量大豆、豬肉等農產品,國務院關稅稅則委員會將對上述採購予以加徵關稅排除。',
'據印度Zee新聞網站12日報道,亞洲新聞國際通訊社援引印度軍方訊息人士的話說,9月11日的對峙事件發生在靠近班公錯北岸的實際控制線一帶。',
'儋州市決定,從9月開始,對城市低保、農村低保、特困供養人員、優撫物件、領取失業保險金人員、建檔立卡未脫貧人口等低收入群體共3萬多人,發放豬肉價格補貼,每人每月發放不低於100元補貼,以後發放標準,將根據豬肉價波動情況進行動態調整。',
'9月11日,華為心聲社群釋出美國經濟學家托馬斯.弗裡德曼在《紐約時報》上的專欄內容,弗裡德曼透露,在與華為創始人任正非最近一次採訪中,任正非表示華為願意與美國司法部展開話題不設限的討論。',
'造血幹細胞移植治療白血病技術已日益成熟,然而,通過該方法同時治癒艾滋病目前還是一道全球尚在攻克的難題。',
'英國航空事故調查局(AAIB)近日披露,今年2月6日一趟由德國法蘭克福飛往墨西哥坎昆的航班上,因飛行員打翻咖啡使操作面板冒煙,導致飛機折返迫降愛爾蘭。',
'當地時間週四(9月12日),印度尼西亞財政部長英卓華(Sri Mulyani Indrawati)明確表示:特朗普的推特是風險之一。',
'華中科技大學9月12日通過其官方網站釋出通報稱,9月2日,我校一碩士研究生不幸墜樓身亡。',
'微博使用者@ooooviki 9月12日下午公佈發生在自己身上的驚悚遭遇:一個自稱網警、名叫鄭洋的人利用職務之便,查到她的完備的個人資訊,包括但不限於身份證號、家庭地址、電話號碼、戶籍變動情況等,要求她做他女朋友。',
'今天,貴陽取消了汽車限購,成為目前全國實行限購政策的9個省市中,首個取消限購的城市。',
'據悉,與全球同步,中國區此次將於9月13日於iPhone官方渠道和京東正式開啟預售,京東成Apple中國區唯一官方授權預售渠道。',
'根據央行公佈的資料,截至2019年6月末,存款類金融機構住戶部門短期消費貸款規模為9.11萬億元,2019年上半年該項淨增3293.19億元,上半年增量看起來並不樂觀。',
'9月11日,一段拍攝浙江萬里學院學生食堂的視訊走紅網路,視訊顯示該學校食堂不僅在用餐區域設定了可以看電影、比賽的大螢幕,還推出了“一人食”餐位。',
'當日,在北京舉行的2019年國際籃聯籃球世界盃半決賽中,西班牙隊對陣澳大利亞隊。',
]
print(len(texts))
for text in texts:
url = 'http://localhost:16016/model_predict?text=%s' % text
req = requests.get(url)
print(json.loads(req.content))
t2 = time.time()
print(round(t2-t1, 4))
執行該程式碼,輸出的結果如下:(預測文字中的時間)
一共預測15個句子。
['9月9日至11日']
['日前', '10月1日', '即日']
['12日', '9月11日']
['9月']
['9月11日']
[]
['近日', '今年2月6日']
['當地時間週四(9月12日)']
['9月12日', '9月2日']
['9月12日下午']
['今天', '目前']
['9月13日']
['2019年6月末', '2019年上半年', '上半年']
['9月11日']
['當日', '2019年']
預測耗時: 15.1085s.
模型預測的效果還是不錯的,但平均每句話的預測時間為1秒多,模型預測時間還是稍微偏長,後續筆者將會研究如何縮短模型預測的時間。
總結
本專案主要是介紹瞭如何利用tensorflow-serving部署kashgari模型,該專案已經上傳至github,地址為:https://github.com/percent4/tensorflow-serving_4_kashgari 。
至於如何縮短模型預測的時間,筆者還需要再繼續研究,歡迎大家關注