pytorch模型轉trt部署
阿新 • • 發佈:2022-05-06
pytorch 轉onnx
首先載入pytorch模型
# load model
import torch

def load_model(ckpt):
    """Build the model and restore its weights from a checkpoint.

    Args:
        ckpt: Path to a checkpoint produced by torch.save; expected to
            contain a "model_state" entry holding the state dict.

    Returns:
        The model with the checkpoint weights loaded.
    """
    # build model — build_model() is project-specific; depends on your own
    # model build function.
    model = build_model()
    # load ckpt; map_location='cpu' lets this work on machines without a GPU.
    checkpoint = torch.load(ckpt, map_location=torch.device('cpu'))
    model.load_state_dict(checkpoint["model_state"])
    return model
使用torch.onnx將pytorch 模型轉為onnx
def export_onnx(model, onnx_name, batch_size, height=224, width=224):
    """Export a PyTorch model to an ONNX file with a fixed input shape.

    Args:
        model: Model to export (must be on CUDA, matching the dummy input).
        onnx_name: Output path for the .onnx file.
        batch_size: Batch dimension baked into the exported graph.
        height: Input image height. The original code read free globals
            ``height``/``width`` (NameError unless defined elsewhere); they
            are now explicit parameters with placeholder defaults —
            adjust to your model's expected input size.
        width: Input image width (see ``height``).
    """
    # Dummy input: tracing fixes the graph's input shape to this tensor.
    img = torch.randn((batch_size, 3, height, width)).cuda()
    torch.onnx.export(
        model,
        img,
        onnx_name,
        export_params=True,       # store the trained weights in the file
        opset_version=11,
        input_names=["input"],
        output_names=["output"],
        do_constant_folding=True,
        verbose=True,
    )
onnx 轉 trt
首先要安裝 TensorRT,安裝教程可以參考 NVIDIA 官方安裝文件。安裝之後可以選擇以下兩種方式進行轉換:1. 使用 trtexec 命令;2. 使用 Python 指令碼轉換。
- trtexec命令
trtexec --onnx=path/to/onnx --saveEngine=path/to/save/trt --explicitBatch --fp16 --workspace=15000
如果提示trtexec command not found, 找到你的tensorrt安裝目錄,例如/usr/local/tensorrt, 將上述中的trtexec替換為/usr/local/tensorrt/bin/trtexec,如果嫌麻煩的話可以在~/.bashrc
中新增下邊一句
alias trtexec="/usr/local/tensorrt/bin/trtexec"
儲存退出然後source ~/.bashrc就可以使用trtexec命令了
- python指令碼
# Logger shared by every TensorRT builder/parser/runtime object below.
TRT_LOGGER = trt.Logger(trt.Logger.INFO)
# Bitmask for create_network() selecting explicit-batch mode.
EXPLICIT_BATCH = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
def get_engine(onnx_file_path, engine_file_path, using_half):
    """Attempts to load a serialized engine if available, otherwise builds a new TensorRT engine and saves it.

    Args:
        onnx_file_path: Path to the source ONNX model.
        engine_file_path: Path to read/write the serialized .trt engine.
        using_half: If True, enable FP16 precision for the build.

    Returns:
        A deserialized or freshly built trt.ICudaEngine.

    Raises:
        FileNotFoundError: If the ONNX file does not exist.
        RuntimeError: If the ONNX file fails to parse.
        AssertionError: If TensorRT fails to build the engine.
    """
    def build_engine():
        """Takes an ONNX file and creates a TensorRT engine to run inference with"""
        device = torch.device('cuda:{}'.format(0))
        with trt.Builder(TRT_LOGGER) as builder, \
                builder.create_network(EXPLICIT_BATCH) as network, \
                trt.OnnxParser(network, TRT_LOGGER) as parser:
            config = builder.create_builder_config()
            # 1 GiB of workspace for tactic selection during the build.
            config.max_workspace_size = 1 << 30
            if using_half:
                config.set_flag(trt.BuilderFlag.FP16)
            # Parse model file
            if not os.path.exists(onnx_file_path):
                # BUG FIX: the original called exit(0), which reports
                # success to the shell; a missing input is an error.
                raise FileNotFoundError(
                    'ONNX file {} not found, please first to generate it.'.format(onnx_file_path))
            with open(onnx_file_path, 'rb') as model:
                print('Beginning ONNX file parsing')
                # BUG FIX: parse() returns False on failure; the original
                # ignored it and went on to build from a half-parsed
                # network. Surface the parser errors and abort instead.
                if not parser.parse(model.read()):
                    for i in range(parser.num_errors):
                        print(parser.get_error(i))
                    raise RuntimeError(
                        'Failed to parse ONNX file {}'.format(onnx_file_path))
            with torch.cuda.device(device):
                engine = builder.build_engine(network, config)
                assert engine is not None, 'Failed to create TensorRT engine'
                # Serialize to disk so subsequent runs can skip the build.
                with open(engine_file_path, "wb") as f:
                    f.write(engine.serialize())
                return engine

    if os.path.exists(engine_file_path):
        # If a serialized engine exists, use it instead of building an engine.
        print("Reading engine from file {}".format(engine_file_path))
        with open(engine_file_path, "rb") as f, trt.Runtime(TRT_LOGGER) as runtime:
            return runtime.deserialize_cuda_engine(f.read())
    else:
        return build_engine()
if __name__ == '__main__':
    # only works for TRT. perf reported by torch is working on non-batched data.
    trt_batch = 1
    fp16 = True
    net_name = 'your_model_name'
    ckpt_path = 'path/to/pth'
    onnx_file = '{name}.onnx'.format(name=net_name)
    # Inference-only workflow: disable autograd for the whole pipeline.
    with torch.no_grad():
        net = load_model(ckpt_path)
        export_onnx(net, onnx_file, trt_batch)
        engine = get_engine(
            onnx_file,
            '{name}.trt'.format(name=net_name),
            fp16,
        )
加速前處理一張圖片大約需要 50ms,加速後的推理速度為 10ms。