1. 程式人生 > 其它 >tensorflow自定義op_TensorRT加速tensorflow模型

tensorflow自定義op_TensorRT加速tensorflow模型

技術標籤:tensorflow自定義op

一. 動機

目前最新版本的TensorRT已經支援tensorflow1.x和tensorflow2.x版本的模型,由於個人之前的模型是基於tf1.14訓練的,為了進一步對模型進行加速,因此本文主要對基於1.14的模型使用TensorRT進行加速。

二. Tensorflow的優勢

目前tensorflow中已經繼承了tensorrt模組,因此只要有savedmodel就可以用TensorRT進行加速,而不需要像Pytorch需要轉格式之後再用TensorRT加速,同時基於tensorflow內部的tensorrt模組,可以避免寫plugin來支援對應的運算元,當一個op無法被TensorRT進行轉換時,模型會仍舊使用tensorflow原有的運算元,需要進一步加速時才需要新增自定義運算元,當對應的op不是十分耗時的情況時,可以使用tensorflow的運算元進行計算,轉換後依舊可以被儲存為saved_

model格式,從而使用tfserving進行部署。

三.轉換

首先需要有一個轉換好的saved_model(可以見上一篇講述tensorflow-serving的文章https://zhuanlan.zhihu.com/p/104960285), 如以下格式

0cf8c554e92b7d37bd8c86278ae9ed44.png
import tensorflow as tf
import cv2
import base64
from tensorflow.python.compiler.tensorrt import trt_convert as trt

saved_model_dir = "./export_model_0126/1581080318"
output_saved_model_dir = "./convert_INT8_export_model"
fetch_names = ["strided_slice_256:0", "cond/Merge:0", "strided_slice_1:0", "ExpandDims:0",
               "Const_39:0", "strided_slice_258:0", "Shape:0", "Cast:0",
               "strided_slice_260:0", "combined_non_max_suppression/CombinedNonMaxSuppression:3",
               "cond/Merge_1:0", "strided_slice_2:0"]

class feed_dict_input_fn():
    def __init__(self, filename):
        self.filename = filename
        self.content = []
        with open(self.filename) as f:
            for line in f:
                self.content.append(line.strip())
        self.index = 0
    def __call__(self, *args, **kwargs):
        data = open(self.content[self.index], 'rb').read()
        encode = base64.urlsafe_b64encode(data)
        encode = str(encode, encoding='utf-8')
        image = {"input:0": encode}
        # value = {"inputs": image}
        self.index += 1
        return image


converter = trt.TrtGraphConverter(input_saved_model_dir=saved_model_dir,
                                  precision_mode=trt.TrtPrecisionMode.INT8,
                                  use_calibration=True, is_dynamic_op=True, maximum_cached_engines=3)
feet_dict_input = feed_dict_input_fn("/home/admin-seu/TempData/sss/Master_work/data/test.list")
converter.convert()
converter.calibrate(fetch_names=fetch_names, num_runs=100, feed_dict_fn=feet_dict_input)
converter.save(output_saved_model_dir)

上述程式碼塊使用了tensorflow內部的tensorrt模組來對模型加速,其中的fetch_names同樣可以參考上一篇文章中的saved_model_cli工具獲取模型的輸出tensor的名字,calibrate函式是用一串輸入資料集對模型進行校準,這是由於TensorRT的INT8需要對資料進行歸一化,因此校準是必要的。轉換完成後,會得到下圖中的INT8模型,同樣可以對應生成FP16和FP32的模型。

7e946d73407507168df738489444674e.png

四.測試

使用如下程式碼塊進行簡單測試:

port tensorflow as tf
import numpy as np
import base64
import time

# output_saved_model_dir = "./convert_export_model"
# output_saved_model_dir = "./export_model_0126/1581080318"
output_saved_model_dir = "./convert_INT8_export_model"
# output_saved_model_dir = "./convert_FP32_export_model"

data = open("/home/admin-seu/TempData/test2017/000000258074.jpg", 'rb').read()
encode = base64.urlsafe_b64encode(data)
encode = str(encode, encoding='utf-8')

with tf.Session() as sess:
    tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.SERVING],
                               output_saved_model_dir)
    cur_graph = sess.graph
    node_names = [tensor.name for tensor in sess.graph_def.node]
    output_tensors = []
    input_tensor = cur_graph.get_tensor_by_name("input:0")
    output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_256:0"))
    output_tensors.append(cur_graph.get_tensor_by_name("strided_slice_260:0"))
    # for node_name in node_names:
    #     if "input" in node_name:
    #         print(node_name)
    #         input_tensor = cur_graph.get_tensor_by_name(node_name)
    #     if "scores_1" in node_name:
    #         print(node_name)
    #         output_tensors.append(cur_graph.get_tensor_by_name(node_name))
    #     if "labels" in node_name:
    #         print(node_name)
    #         output_tensors.append(cur_graph.get_tensor_by_name(node_name))
    #     if "all_ids" in node_name:
    #         print(node_name)
    #         output_tensors.append(cur_graph.get_tensor_by_name(node_name))
    #     if "boxes_1" in node_name:
    #         print(node_name)
    #         output_tensors.append(cur_graph.get_tensor_by_name(node_name))
    output = sess.run(output_tensors, feed_dict={input_tensor: encode})
    print(np.shape(output[0]))
    for i in range(10):
        output = sess.run(output_tensors, feed_dict={input_tensor: encode})
    start = time.time()
    print(start)
    for i in range(100):
        output = sess.run(output_tensors, feed_dict={input_tensor: encode})
    end = time.time()
    print(end)
    print(end - start)

最終可以得到加速後的模型的速度,由於本文使用的檢測模型使用了大量TensorRT不支援的運算元,因此加速效果比較有限,大概能比原先模型提升10%的速度。在一些更為簡單的任務上相信模型能得到更大的加速比。

5.總結

可見使用TensorRT對tensorflow的模型加速是十分簡單的,基本開箱即用,程式碼已放置在https://github.com/smallsunsun1/Cascade-RCNN,最後總結一下就是Tensorflow Yes, 後續可能記錄一下鴿了很久的對TensorFlow原始碼部分的閱讀筆記了只剩,立一個Flag,希望有時間有空自己補上0.0