tensorflow自定義op_TensorRT加速tensorflow模型
技術標籤:tensorflow自定義op
一. 動機
目前最新版本的TensorRT已經支援tensorflow1.x和tensorflow2.x版本的模型,由於個人之前的模型是基於tf1.14訓練的,為了進一步對模型進行加速,因此本文主要對基於1.14的模型使用TensorRT進行加速。
二. Tensorflow的優勢
目前tensorflow中已經繼承了tensorrt模組,因此只要有savedmodel就可以用TensorRT進行加速,而不需要像Pytorch需要轉格式之後再用TensorRT加速,同時基於tensorflow內部的tensorrt模組,可以避免寫plugin來支援對應的運算元,當一個op無法被TensorRT進行轉換時,模型會仍舊使用tensorflow原有的運算元,需要進一步加速時才需要新增自定義運算元,當對應的op不是十分耗時的情況時,可以使用tensorflow的運算元進行計算,轉換後依舊可以被儲存為saved_
三.轉換
首先需要有一個轉換好的saved_model(可以見上一篇講述tensorflow-serving的文章https://zhuanlan.zhihu.com/p/104960285), 如以下格式
import tensorflow as tf import cv2 import base64 from tensorflow.python.compiler.tensorrt import trt_convert as trt saved_model_dir = "./export_model_0126/1581080318" output_saved_model_dir = "./convert_INT8_export_model" fetch_names = ["strided_slice_256:0", "cond/Merge:0", "strided_slice_1:0", "ExpandDims:0", "Const_39:0", "strided_slice_258:0", "Shape:0", "Cast:0", "strided_slice_260:0", "combined_non_max_suppression/CombinedNonMaxSuppression:3", "cond/Merge_1:0", "strided_slice_2:0"] class feed_dict_input_fn(): def __init__(self, filename): self.filename = filename self.content = [] with open(self.filename) as f: for line in f: self.content.append(line.strip()) self.index = 0 def __call__(self, *args, **kwargs): data = open(self.content[self.index], 'rb').read() encode = base64.urlsafe_b64encode(data) encode = str(encode, encoding='utf-8') image = {"input:0": encode} # value = {"inputs": image} self.index += 1 return image converter = trt.TrtGraphConverter(input_saved_model_dir=saved_model_dir, precision_mode=trt.TrtPrecisionMode.INT8, use_calibration=True, is_dynamic_op=True, maximum_cached_engines=3) feet_dict_input = feed_dict_input_fn("/home/admin-seu/TempData/sss/Master_work/data/test.list") converter.convert() converter.calibrate(fetch_names=fetch_names, num_runs=100, feed_dict_fn=feet_dict_input) converter.save(output_saved_model_dir)
上述程式碼塊使用了tensorflow內部的tensorrt模組來對模型加速,其中的fetch_names同樣可以參考上一篇文章中的saved_model_cli工具獲取模型的輸出tensor的名字,calibrate函式是用一串輸入資料集對模型進行校準,這是由於TensorRT的INT8需要對資料進行歸一化,因此校準是必要的。轉換完成後,會得到下圖中的INT8模型,同樣可以對應生成FP16和FP32的模型。
四.測試
使用如下程式碼塊進行簡單測試:
port tensorflow as tf import numpy as np import base64 import time # output_saved_model_dir = "./convert_export_model" # output_saved_model_dir = "./export_model_0126/1581080318" output_saved_model_dir = "./convert_INT8_export_model" # output_saved_model_dir = "./convert_FP32_export_model" data = open("/home/admin-seu/TempData/test2017/000000258074.jpg", 'rb').read() encode = base64.urlsafe_b64encode(data) encode = str(encode, encoding='utf-8') with tf.Session() as sess: tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING], output_saved_model_dir) cur_graph = sess.graph node_names = [tensor.name for tensor in sess.graph_def.node] output_tensors = [] input_tensor = cur_graph.get_tensor_by_name("input:0") output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_256:0")) output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_260:0")) # for node_name in node_names: # if "input" in node_name: # print(node_name) # input_tensor = cur_graph.get_tensor_by_name(node_name) # if "scores_1" in node_name: # print(node_name) # output_tensors.append(cur_graph.get_tensor_by_name(node_name)) # if "labels" in node_name: # print(node_name) # output_tensors.append(cur_graph.get_tensor_by_name(node_name)) # if "all_ids" in node_name: # print(node_name) # output_tensors.append(cur_graph.get_tensor_by_name(node_name)) # if "boxes_1" in node_name: # print(node_name) # output_tensors.append(cur_graph.get_tensor_by_name(node_name)) output = sess.run(output_tensors, feed_dict={input_tensor: encode}) print(np.shape(output[0])) for i in range(10): output = sess.run(output_tensors, feed_dict={input_tensor: encode}) start = time.time() print(start) for i in range(100): output = sess.run(output_tensors, feed_dict={input_tensor: encode}) end = time.time() print(end) print(end - start)
最終可以得到加速後的模型的速度,由於本文使用的檢測模型使用了大量TensorRT不支援的運算元,因此加速效果比較有限,大概能比原先模型提升10%的速度。在一些更為簡單的任務上相信模型能得到更大的加速比。
5.總結
可見使用TensorRT對tensorflow的模型加速是十分簡單的,基本開箱即用,程式碼已放置在https://github.com/smallsunsun1/Cascade-RCNN,最後總結一下就是Tensorflow Yes, 後續可能記錄一下鴿了很久的對TensorFlow原始碼部分的閱讀筆記了只剩,立一個Flag,希望有時間有空自己補上0.0