1. 程式人生 > 其它 >pytorch轉onnx驗證_端側部署好助手:pytorch 模型轉 onnx,並驗證結果

pytorch轉onnx驗證_端側部署好助手:pytorch 模型轉 onnx,並驗證結果

技術標籤:pytorch轉onnx驗證

ONNX(Open Neural Network Exchange)是一種針對機器學習所設計的開放式的檔案格式,用於儲存訓練好的模型。它使得不同的人工智慧框架(如Pytorch、MXNet)可以採用相同格式儲存模型資料並互動。ONNX的規範及程式碼主要由微軟,亞馬遜,Facebook和IBM等公司共同開發,以開放原始碼的方式託管在Github上。[1][2][3]目前官方支援載入ONNX模型並進行推理的深度學習框架有:Caffe2, PyTorch, MXNet,ML.NET,TensorRT 和 Microsoft CNTK,並且 TensorFlow 也非官方的支援ONNX。

pytorch 也是我們經常用到的深度學習框架,有些時候比如端側部署的時候我們需要將pytorch 轉成 onnx ,然後轉成 ncnn 、mcnn這些端側推理框架能使用的模型,下面教大家怎麼把 pytorch 模型轉成 onnx 模型。

為了方便,我們使用 torchvision 裡的預先定義好的模型 resnet18.

import torchimport torchvisiondummy_input = torch.randn(10, 3, 224, 224,)model = torchvision.models.resnet18(pretrained=True)# 為輸入輸出起個名字input_names = [ "input_node" ]output_names = [ "output" ]torch.onnx.export(model, dummy_input, "resnet18.onnx", verbose=True, input_names=input_names, output_names=output_names)

這樣,我們能成功匯出 onnx 模型。resnet18.onnx是二進位制的 protobuf 檔案,包含了模型的結構和引數,這裡邊有一個關鍵字是verbose=True,會打印出人類可讀性的網路表達:

graph(%input_node : Float(10, 3, 224, 224),
%conv1.weight : Float(64, 3, 7, 7),
%bn1.weight : Float(64),
%bn1.bias : Float(64),
%bn1.running_mean : Float(64),
%bn1.running_var : Float(64),
%bn1.num_batches_tracked : Long(),
%layer1.0.conv1.weight : Float(64, 64, 3, 3),... 篇幅原因略 %169 : Float(10, 256, 14, 14) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%168, %layer3.1.conv2.weight), scope: ResNet/Sequential[layer3]/BasicBlock[1]/Conv2d[conv2] # torch/nn/modules/conv.py:342:0
%170 : Float(10, 256, 14, 14) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%169, %layer3.1.bn2.weight, %layer3.1.bn2.bias, %layer3.1.bn2.running_mean, %layer3.1.bn2.running_var), scope: ResNet/Sequential[layer3]/BasicBlock[1]/BatchNorm2d[bn2] # torch/nn/functional.py:1670:0
%171 : Float(10, 256, 14, 14) = onnx::Add(%170, %165), scope: ResNet/Sequential[layer3]/BasicBlock[1] # torchvision/models/resnet.py:69:0
%172 : Float(10, 256, 14, 14) = onnx::Relu(%171), scope: ResNet/Sequential[layer3]/BasicBlock[1]/ReLU[relu] # torch/nn/functional.py:912:0
%173 : Float(10, 512, 7, 7) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[2, 2]](%172, %layer4.0.conv1.weight), scope: ResNet/Sequential[layer4]/BasicBlock[0]/Conv2d[conv1] # torch/nn/modules/conv.py:342:0
%174 : Float(10, 512, 7, 7) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%173, %layer4.0.bn1.weight, %layer4.0.bn1.bias, %layer4.0.bn1.running_mean, %layer4.0.bn1.running_var), scope: ResNet/Sequential[layer4]/BasicBlock[0]/BatchNorm2d[bn1] # torch/nn/functional.py:1670:0
%175 : Float(10, 512, 7, 7) = onnx::Relu(%174), scope: ResNet/Sequential[layer4]/BasicBlock[0]/ReLU[relu] # torch/nn/functional.py:912:0
%176 : Float(10, 512, 7, 7) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%175, %layer4.0.conv2.weight), scope: ResNet/Sequential[layer4]/BasicBlock[0]/Conv2d[conv2] # torch/nn/modules/conv.py:342:0
%177 : Float(10, 512, 7, 7) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%176, %layer4.0.bn2.weight, %layer4.0.bn2.bias, %layer4.0.bn2.running_mean, %layer4.0.bn2.running_var), scope: ResNet/Sequential[layer4]/BasicBlock[0]/BatchNorm2d[bn2] # torch/nn/functional.py:1670:0
%178 : Float(10, 512, 7, 7) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[1, 1], pads=[0, 0, 0, 0], strides=[2, 2]](%172, %layer4.0.downsample.0.weight), scope: ResNet/Sequential[layer4]/BasicBlock[0]/Sequential[downsample]/Conv2d[0] # torch/nn/modules/conv.py:342:0
%179 : Float(10, 512, 7, 7) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%178, %layer4.0.downsample.1.weight, %layer4.0.downsample.1.bias, %layer4.0.downsample.1.running_mean, %layer4.0.downsample.1.running_var), scope: ResNet/Sequential[layer4]/BasicBlock[0]/Sequential[downsample]/BatchNorm2d[1] # torch/nn/functional.py:1670:0
%180 : Float(10, 512, 7, 7) = onnx::Add(%177, %179), scope: ResNet/Sequential[layer4]/BasicBlock[0] # torchvision/models/resnet.py:69:0
%181 : Float(10, 512, 7, 7) = onnx::Relu(%180), scope: ResNet/Sequential[layer4]/BasicBlock[0]/ReLU[relu] # torch/nn/functional.py:912:0
%182 : Float(10, 512, 7, 7) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%181, %layer4.1.conv1.weight), scope: ResNet/Sequential[layer4]/BasicBlock[1]/Conv2d[conv1] # torch/nn/modules/conv.py:342:0
%183 : Float(10, 512, 7, 7) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%182, %layer4.1.bn1.weight, %layer4.1.bn1.bias, %layer4.1.bn1.running_mean, %layer4.1.bn1.running_var), scope: ResNet/Sequential[layer4]/BasicBlock[1]/BatchNorm2d[bn1] # torch/nn/functional.py:1670:0
%184 : Float(10, 512, 7, 7) = onnx::Relu(%183), scope: ResNet/Sequential[layer4]/BasicBlock[1]/ReLU[relu] # torch/nn/functional.py:912:0
%185 : Float(10, 512, 7, 7) = onnx::Conv[dilations=[1, 1], group=1, kernel_shape=[3, 3], pads=[1, 1, 1, 1], strides=[1, 1]](%184, %layer4.1.conv2.weight), scope: ResNet/Sequential[layer4]/BasicBlock[1]/Conv2d[conv2] # torch/nn/modules/conv.py:342:0
%186 : Float(10, 512, 7, 7) = onnx::BatchNormalization[epsilon=1e-05, momentum=0.9](%185, %layer4.1.bn2.weight, %layer4.1.bn2.bias, %layer4.1.bn2.running_mean, %layer4.1.bn2.running_var), scope: ResNet/Sequential[layer4]/BasicBlock[1]/BatchNorm2d[bn2] # torch/nn/functional.py:1670:0
%187 : Float(10, 512, 7, 7) = onnx::Add(%186, %181), scope: ResNet/Sequential[layer4]/BasicBlock[1] # torchvision/models/resnet.py:69:0
%188 : Float(10, 512, 7, 7) = onnx::Relu(%187), scope: ResNet/Sequential[layer4]/BasicBlock[1]/ReLU[relu] # torch/nn/functional.py:912:0
%189 : Float(10, 512, 1, 1) = onnx::GlobalAveragePool(%188), scope: ResNet/AdaptiveAvgPool2d[avgpool] # torch/nn/functional.py:768:0
%190 : Float(10, 512) = onnx::Flatten[axis=1](%189), scope: ResNet # torchvision/models/resnet.py:207:0
%output : Float(10, 1000) = onnx::Gemm[alpha=1, beta=1, transB=1](%190, %fc.weight, %fc.bias), scope: ResNet/Linear[fc] # torch/nn/functional.py:1370:0
return (%output)

為了驗證模型的正確性,我們需要安裝下 onnx ,使用 onnx 執行時來驗證

pip install onnx

import onnxruntime as ortort_session = ort.InferenceSession('resnet18.onnx')outputs = ort_session.run(None, {'input_node': np.random.randn(10, 3, 224, 224).astype(np.float32)})print(outputs[0])

這時只需要把onnx和pytorch模型的輸出列印對比即可確認,注意模型的輸入要完全一致,即預處理等完全一致。

關注我們,一起學習

a7838c19db6e1796f5b025c75065e8b2.png