1. 程式人生 > >NLP(十八)利用ALBERT提升模型預測速度的一次嘗試

NLP(十八)利用ALBERT提升模型預測速度的一次嘗試

前沿

  在文章NLP(十七)利用tensorflow-serving部署kashgari模型中,筆者介紹瞭如何利用tensorflow-serving部署來部署深度模型模型,在那篇文章中,筆者利用kashgari模組實現了經典的BERT+Bi-LSTM+CRF模型結構,在標註了時間的文字語料(大約2000多個訓練句子)中也達到了很好的識別效果,但是也存在著不足之處,那就是模型的預測時間過長,平均預測一個句子中的時間耗時約400毫秒,這種預測速度在生產環境或實際應用中是不能忍受的。
  檢視該模型的耗時原因,很大一部分原因在於BERT的呼叫。BERT是當下最火,知名度最高的預訓練模型,雖然會使得模型的訓練、預測耗時增加,但也是小樣本語料下的最佳模型工具之一,因此,BERT在模型的架構上是不可缺少的。那麼,該如何避免使用預訓練模型帶來的模型預測耗時過長的問題呢?
  本文決定嘗試使用ALBERT,來驗證ALBERT在提升模型預測速度方面的應用,同時,也算是本人對於使用ALBERT的一次實戰吧~

ALBERT簡介

  我們不妨花一些時間來簡單地瞭解一下ALBERT。ALBERT是最近一週才開源的預訓練模型,其Github的網址為:https://github.com/brightmart/albert_zh ,其論文可以參考網址:https://arxiv.org/pdf/1909.11942.pdf 。
  根據ALBERT的Github介紹,ALBERT在海量中文語料上進行了預訓練,模型的引數更少,效果更好。以albert_tiny_zh為例,其檔案大小16M、引數為1.8M,模型大小僅為BERT的1/25,效果僅比BERT略差或者在某些NLP任務上更好。在本文的預訓練模型中,將採用albert_tiny_zh。

利用ALBERT訓練時間識別模型

  我們以Github中的bertNER為本次專案的程式碼模板,在該專案中,實現的模型為BERT+Bi-LSTM+CRF,我們將BERT替換為ALBERT,也就是說筆者的專案中模型為ALBERT+Bi-LSTM+CRF,同時替換bert資料夾的程式碼為alert_zh,替換預訓練模型資料夾chinese_L-12_H-768_A-12(BERT中文預訓練模型檔案)為albert_tiny。當然,也需要修改一部分的專案原始碼,來適應ALBERT的模型訓練。
  資料集採用筆者自己標註的時間語料,即標註了時間的句子,大概2000+句子,其中75%作為訓練集(time.train檔案),10%作為驗證集(time.dev檔案),15%作為測試集(time.test檔案)。在這裡筆者不打算給出具體的Python程式碼,因為工程比較複雜,有興趣的額讀者可以去檢視該專案的Github地址:。
  一些模型的引數可以如下:

  • 預訓練模型:ALBERT(tiny)
  • 訓練樣本的最大字元長度: 128
  • batch_size: 8
  • epoch: 100
  • 雙向LSTM的個數:100

  ALBERT的模型訓練時間也會顯著提高,我們耐心地等待模型訓練完畢。在time.dev和time.test資料集上的表現如下表:

資料集 precision recall f1
time.dev 81.41% 84.95% 83.14%
time.test 83.03% 86.38% 84.67%

  接著筆者利用訓練好的模型,用tornado封裝了一個模型預測的HTTP服務,具體的程式碼如下:

# -*- coding: utf-8 -*-

import os
import json
import time
import pickle
import traceback

import tornado.httpserver
import tornado.ioloop
import tornado.options
import tornado.web
from tornado.options import define, options

import tensorflow as tf
from utils import create_model, get_logger
from model import Model
from loader import input_from_line
from train import FLAGS, load_config, train

# 定義埠為12306
define("port", default=12306, help="run on the given port", type=int)
# 匯入模型
config = load_config(FLAGS.config_file)
logger = get_logger(FLAGS.log_file)
# limit GPU memory
tf_config = tf.ConfigProto()
tf_config.gpu_options.allow_growth = False
with open(FLAGS.map_file, "rb") as f:
    tag_to_id, id_to_tag = pickle.load(f)

sess = tf.Session(config=tf_config)
model = create_model(sess, Model, FLAGS.ckpt_path, config, logger)

# 模型預測的HTTP介面
class ResultHandler(tornado.web.RequestHandler):
    # post函式
    def post(self):
        event = self.get_argument('event')
        result = model.evaluate_line(sess, input_from_line(event, FLAGS.max_seq_len, tag_to_id), id_to_tag)
        self.write(json.dumps(result, ensure_ascii=False))

# 主函式
def main():
    # 開啟tornado服務
    tornado.options.parse_command_line()
    # 定義app
    app = tornado.web.Application(
            handlers=[
                      (r'/subj_extract', ResultHandler)
                     ], #網頁路徑控制
           )
    http_server = tornado.httpserver.HTTPServer(app)
    http_server.listen(options.port)
    tornado.ioloop.IOLoop.instance().start()

main()

模型預測提速了嗎?

  將模型預測封裝成HTTP服務後,我們利用Postman來測試模型預測的效果和時間,如下圖所示:

可以看到,模型預測的結果正確,且耗時僅為38ms。
  接著我們嘗試多測試幾個句子的測試,測試程式碼如下:

# Daxing, Beijing
import requests
import json
import time

url = 'http://localhost:12306/subj_extract'

texts = ['據《新聞聯播》報道,9月9日至11日,中央紀委書記趙樂際到河北調研。',
         '記者從國家發展改革委、商務部相關方面獲悉,日前美方已決定對擬於10月1日實施的中國輸美商品加徵關稅措施做出調整,中方支援相關企業從即日起按照市場化原則和WTO規則,自美採購一定數量大豆、豬肉等農產品,國務院關稅稅則委員會將對上述採購予以加徵關稅排除。',
         '據印度Zee新聞網站12日報道,亞洲新聞國際通訊社援引印度軍方訊息人士的話說,9月11日的對峙事件發生在靠近班公錯北岸的實際控制線一帶。',
         '儋州市決定,從9月開始,對城市低保、農村低保、特困供養人員、優撫物件、領取失業保險金人員、建檔立卡未脫貧人口等低收入群體共3萬多人,發放豬肉價格補貼,每人每月發放不低於100元補貼,以後發放標準,將根據豬肉價波動情況進行動態調整。',
         '9月11日,華為心聲社群釋出美國經濟學家托馬斯.弗裡德曼在《紐約時報》上的專欄內容,弗裡德曼透露,在與華為創始人任正非最近一次採訪中,任正非表示華為願意與美國司法部展開話題不設限的討論。',
         '造血幹細胞移植治療白血病技術已日益成熟,然而,通過該方法同時治癒艾滋病目前還是一道全球尚在攻克的難題。',
         '英國航空事故調查局(AAIB)近日披露,今年2月6日一趟由德國法蘭克福飛往墨西哥坎昆的航班上,因飛行員打翻咖啡使操作面板冒煙,導致飛機折返迫降愛爾蘭。',
         '當地時間週四(9月12日),印度尼西亞財政部長英卓華(Sri Mulyani Indrawati)明確表示:特朗普的推特是風險之一。',
         '華中科技大學9月12日通過其官方網站釋出通報稱,9月2日,我校一碩士研究生不幸墜樓身亡。',
         '微博使用者@ooooviki 9月12日下午公佈發生在自己身上的驚悚遭遇:一個自稱網警、名叫鄭洋的人利用職務之便,查到她的完備的個人資訊,包括但不限於身份證號、家庭地址、電話號碼、戶籍變動情況等,要求她做他女朋友。',
         '今天,貴陽取消了汽車限購,成為目前全國實行限購政策的9個省市中,首個取消限購的城市。',
         '據悉,與全球同步,中國區此次將於9月13日於iPhone官方渠道和京東正式開啟預售,京東成Apple中國區唯一官方授權預售渠道。',
         '根據央行公佈的資料,截至2019年6月末,存款類金融機構住戶部門短期消費貸款規模為9.11萬億元,2019年上半年該項淨增3293.19億元,上半年增量看起來並不樂觀。',
         '9月11日,一段拍攝浙江萬里學院學生食堂的視訊走紅網路,視訊顯示該學校食堂不僅在用餐區域設定了可以看電影、比賽的大螢幕,還推出了“一人食”餐位。',
         '當日,在北京舉行的2019年國際籃聯籃球世界盃半決賽中,西班牙隊對陣澳大利亞隊。',
         ]

t1 = time.time()
for text in texts:
    data = {'event': text.replace(' ', '')}
    req = requests.post(url, data)
    if req.status_code == 200:
        print('原文:%s' % text)
        res = json.loads(req.content)['entities']
        print('抽取結果:%s' % str([_['word'] for _ in res]))


t2 = time.time()
print('一共耗時:%ss.' % str(round(t2-t1, 4)))

輸出結果如下:

原文:據《新聞聯播》報道,9月9日至11日,中央紀委書記趙樂際到河北調研。
抽取結果:['9月9日至11日']
原文:記者從國家發展改革委、商務部相關方面獲悉,日前美方已決定對擬於10月1日實施的中國輸美商品加徵關稅措施做出調整,中方支援相關企業從即日起按照市場化原則和WTO規則,自美採購一定數量大豆、豬肉等農產品,國務院關稅稅則委員會將對上述採購予以加徵關稅排除。
抽取結果:['日前', '10月1日']
原文:據印度Zee新聞網站12日報道,亞洲新聞國際通訊社援引印度軍方訊息人士的話說,9月11日的對峙事件發生在靠近班公錯北岸的實際控制線一帶。
抽取結果:['12日', '9月11日']
原文:儋州市決定,從9月開始,對城市低保、農村低保、特困供養人員、優撫物件、領取失業保險金人員、建檔立卡未脫貧人口等低收入群體共3萬多人,發放豬肉價格補貼,每人每月發放不低於100元補貼,以後發放標準,將根據豬肉價波動情況進行動態調整。
抽取結果:['9月']
原文:9月11日,華為心聲社群釋出美國經濟學家托馬斯.弗裡德曼在《紐約時報》上的專欄內容,弗裡德曼透露,在與華為創始人任正非最近一次採訪中,任正非表示華為願意與美國司法部展開話題不設限的討論。
抽取結果:['9月11日']
原文:造血幹細胞移植治療白血病技術已日益成熟,然而,通過該方法同時治癒艾滋病目前還是一道全球尚在攻克的難題。
抽取結果:[]
原文:英國航空事故調查局(AAIB)近日披露,今年2月6日一趟由德國法蘭克福飛往墨西哥坎昆的航班上,因飛行員打翻咖啡使操作面板冒煙,導致飛機折返迫降愛爾蘭。
抽取結果:['近日', '今年2月6日']
原文:當地時間週四(9月12日),印度尼西亞財政部長英卓華(Sri Mulyani Indrawati)明確表示:特朗普的推特是風險之一。
抽取結果:['當地時間週四(9月12日)']
原文:華中科技大學9月12日通過其官方網站釋出通報稱,9月2日,我校一碩士研究生不幸墜樓身亡。
抽取結果:['9月12日', '9月2日']
原文:微博使用者@ooooviki 9月12日下午公佈發生在自己身上的驚悚遭遇:一個自稱網警、名叫鄭洋的人利用職務之便,查到她的完備的個人資訊,包括但不限於身份證號、家庭地址、電話號碼、戶籍變動情況等,要求她做他女朋友。
抽取結果:['9月12日下午']
原文:今天,貴陽取消了汽車限購,成為目前全國實行限購政策的9個省市中,首個取消限購的城市。
抽取結果:['今天', '目前']
原文:據悉,與全球同步,中國區此次將於9月13日於iPhone官方渠道和京東正式開啟預售,京東成Apple中國區唯一官方授權預售渠道。
抽取結果:['9月13日']
原文:根據央行公佈的資料,截至2019年6月末,存款類金融機構住戶部門短期消費貸款規模為9.11萬億元,2019年上半年該項淨增3293.19億元,上半年增量看起來並不樂觀。
抽取結果:['2019年6月末', '2019年上半年', '上半年']
原文:9月11日,一段拍攝浙江萬里學院學生食堂的視訊走紅網路,視訊顯示該學校食堂不僅在用餐區域設定了可以看電影、比賽的大螢幕,還推出了“一人食”餐位。
抽取結果:['9月11日']
原文:當日,在北京舉行的2019年國際籃聯籃球世界盃半決賽中,西班牙隊對陣澳大利亞隊。
抽取結果:['當日', '2019年']
一共耗時:0.5314s.

可以看到,對於測試的15個句子,識別的準確率很高,且預測耗時為531ms,平均每個話的預測時間不超過40ms。相比較而言,文章NLP(十七)利用tensorflow-serving部署kashgari模型中的模型,該模型的預測時間為每句話1秒多,模型預測的速度為帶ALBERT模型的25倍多。
  因此,ALBERT模型確實提升了模型預測的時間,而且效&果非常顯著。

總結

  由於ALBERT開源不到一週,而且筆者的學識、才能有限,因此,在程式碼方面可能會存在不足。但是,作為一次使用ALBERT的歷經,希望能夠與大家分享。
  本文絕不是上述專案程式碼的抄襲和堆砌,該專案融入了筆者自己的思考,希望不要被誤解為是抄襲。筆者使用上述的bertNER和ALBERT,只是為了驗證ALBERT在模型預測耗時方面的提速效果,而事實是,ALBERT確實給我帶來了很大驚喜,感受原始碼作者們~
  最後,附上本文中筆者專案的Github地址:https://github.com/percent4/ALBERT_4_Time_Recognition 。
  眾裡尋他千百度。驀然回首,那人卻在,燈火闌珊處。

參考文獻

  1. 超小型BERT中文版橫空出世!模型只有16M,訓練速度提升10倍:https://mp.weixin.qq.com/s/eVlNpejrxdE4ctDTBM-fiA
  2. ALBERT的Github地址:https://github.com/brightmart/albert_zh
  3. bertNER專案的Github地址:https://github.com/yumath/bertNER
  4. NLP(十七)利用tensorflow-serving部署kashgari模型: https://www.cnblogs.com/jclian91/p/11526547.html