1. 程式人生 > 程式設計 >Pytorch轉tflite方式

Pytorch轉tflite方式

目標是想把在伺服器上用pytorch訓練好的模型轉換為可以在移動端執行的tflite模型。

最直接的思路是想把pytorch模型轉換為tensorflow的模型,然後轉換為tflite。但是這個轉換目前沒有發現比較靠譜的方法。

經過調研發現最新的tflite已經支援直接從keras模型的轉換,所以可以採用keras作為中間轉換的橋樑,這樣就能充分利用keras高層API的便利性。

轉換的基本思想就是把pytorch中各層網路的權重取出來後,直接賦值給keras網路中對應layer層的權重。

轉換為Keras模型後,再通過tf.contrib.lite.TocoConverter把模型直接轉為tflite.

下面是一個例子,假設轉換的是一個兩層的CNN網路。

import tensorflow as tf
from tensorflow import keras
import numpy as np

import torch
from torchvision import models
import torch.nn as nn
# import torch.nn.functional as F
from torch.autograd import Variable

class PytorchNet(nn.Module):
    """Two-stage CNN used as the PyTorch side of the conversion example.

    Each stage is Conv2d -> BatchNorm2d -> ReLU -> MaxPool2d(2, 2), so that the
    layers line up one-to-one with the Keras network built by ``KerasNet``.
    """

    def __init__(self):
        super().__init__()
        # Stage 1: 3x3 convolution, stride 2, no padding.
        conv1 = nn.Sequential(
            nn.Conv2d(3, 32, 3, 2),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        # Stage 2: 1x1 convolution, stride 1, no padding.
        conv2 = nn.Sequential(
            nn.Conv2d(32, 64, 1, 1, groups=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
        )
        self.feature = nn.Sequential(conv1, conv2)
        self.init_weights()

    def forward(self, x):
        """Run ``x`` through both stages and return the resulting feature map."""
        return self.feature(x)

    def init_weights(self):
        """Initialise convolutions with Kaiming-normal weights and zero bias,
        and batch-norm layers with weight 1 and bias 0."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(
                    m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

def KerasNet(input_shape=(224, 224, 3)):
    """Build the Keras counterpart of ``PytorchNet``.

    The layer sequence mirrors the PyTorch model stage by stage
    (Conv2D -> BatchNormalization -> ReLU -> MaxPool2D, twice) so that weights
    can be copied across layer by layer.

    Args:
        input_shape: Shape of one input image, channels last (H, W, C).

    Returns:
        A ``keras.Model`` mapping the image input to the last pooling output.
    """
    image_input = keras.layers.Input(shape=input_shape)

    # conv1: 3x3 kernel, stride 2 -- matches nn.Conv2d(3, 32, 3, 2).
    network = keras.layers.Conv2D(
        32, (3, 3), strides=(2, 2), padding="valid")(image_input)
    # epsilon is set explicitly: the Keras default (1e-3) differs from the
    # PyTorch BatchNorm2d default (1e-5), which would skew converted outputs.
    network = keras.layers.BatchNormalization(
        trainable=False, fused=False, epsilon=1e-5)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2))(network)

    # conv2: 1x1 kernel, stride 1 -- matches nn.Conv2d(32, 64, 1, 1).
    network = keras.layers.Conv2D(
        64, (1, 1), strides=(1, 1), padding="valid")(network)
    network = keras.layers.BatchNormalization(
        trainable=False, fused=True, epsilon=1e-5)(network)
    network = keras.layers.Activation("relu")(network)
    network = keras.layers.MaxPool2D(pool_size=(2, 2))(network)

    model = keras.Model(inputs=image_input, outputs=network)

    return model

class PytorchToKeras(object):
    """Copy the weights of a PyTorch model into a structurally matching Keras model.

    Layers are paired by position: the i-th PyTorch module that owns a
    ``weight`` attribute (in forward-execution order) is matched with the i-th
    Keras layer that has weights. Only ``nn.Conv2d`` and ``nn.BatchNorm2d``
    sources are copied; any other paired module is skipped.

    Args:
        pModel: Source PyTorch ``nn.Module``.
        kModel: Target Keras model with the same layer sequence.
    """

    def __init__(self, pModel, kModel):
        super().__init__()
        self.__source_layers = []
        self.__target_layers = []
        self.pModel = pModel
        self.kModel = kModel
        # Put the Keras backend in inference mode (learning phase 0).
        tf.keras.backend.set_learning_phase(0)

    def __retrieve_k_layers(self):
        """Record the indices of the Keras layers that carry weights."""
        # Rebuilt from scratch so repeated convert() calls do not accumulate.
        self.__target_layers = [
            index
            for index, layer in enumerate(self.kModel.layers)
            if len(layer.weights) > 0
        ]

    def __retrieve_p_layers(self, input_size):
        """Record, in execution order, the PyTorch modules that own a ``weight``.

        A dummy forward pass with a random input of shape ``(1, *input_size)``
        is run with forward hooks attached to discover the order.
        """
        self.__source_layers = []
        dummy_input = torch.randn(input_size).unsqueeze(0)
        handles = []

        def hook(module, inputs, output):
            if hasattr(module, "weight"):
                self.__source_layers.append(module)

        def add_hooks(module):
            # Containers and the root model are skipped; only leaves are hooked.
            is_container = isinstance(module, (nn.ModuleList, nn.Sequential))
            if not is_container and module is not self.pModel:
                handles.append(module.register_forward_hook(hook))

        self.pModel.apply(add_hooks)

        # The probe must run in eval mode: in training mode BatchNorm would
        # update running_mean/running_var from the random input and corrupt
        # the statistics that are about to be copied.
        saved_modes = [(m, m.training) for m in self.pModel.modules()]
        self.pModel.eval()
        try:
            with torch.no_grad():
                self.pModel(dummy_input)
        finally:
            for handle in handles:
                handle.remove()
            for module, mode in saved_modes:
                module.training = mode

    @staticmethod
    def __to_numpy(tensor):
        """Return ``tensor`` as a NumPy array detached from autograd, on CPU."""
        return tensor.detach().cpu().numpy()

    def convert(self, input_size):
        """Transfer weights from ``pModel`` to ``kModel``.

        Args:
            input_size: Shape of a single PyTorch input without the batch
                dimension, e.g. ``(3, 224, 224)``.

        Raises:
            ValueError: If the number of weighted PyTorch modules differs from
                the number of weighted Keras layers, since positional pairing
                would then assign weights to the wrong layers.
        """
        self.__retrieve_k_layers()
        self.__retrieve_p_layers(input_size)

        if len(self.__source_layers) != len(self.__target_layers):
            raise ValueError(
                "Layer count mismatch: %d weighted PyTorch modules vs %d "
                "weighted Keras layers" % (
                    len(self.__source_layers), len(self.__target_layers)))

        for source_layer, target_index in zip(
                self.__source_layers, self.__target_layers):
            print(source_layer)
            target_layer = self.kModel.layers[target_index]
            if isinstance(source_layer, nn.Conv2d):
                # PyTorch kernels are (out, in, kH, kW); Keras Conv2D expects
                # (kH, kW, in, out). NOTE(review): assumes groups == 1.
                kernel = self.__to_numpy(source_layer.weight).transpose(
                    2, 3, 1, 0)
                weights = [kernel]
                if source_layer.bias is not None:
                    weights.append(self.__to_numpy(source_layer.bias))
                target_layer.set_weights(weights)
            elif isinstance(source_layer, nn.BatchNorm2d):
                # Order passed to Keras: gamma, beta, moving mean, moving variance.
                target_layer.set_weights([
                    self.__to_numpy(source_layer.weight),
                    self.__to_numpy(source_layer.bias),
                    self.__to_numpy(source_layer.running_mean),
                    self.__to_numpy(source_layer.running_var),
                ])

    def save_model(self, output_file):
        """Save the full Keras model (architecture and weights) to ``output_file``."""
        self.kModel.save(output_file)

    def save_weights(self, output_file):
        """Save only the Keras weights to ``output_file`` in HDF5 format."""
        self.kModel.save_weights(output_file, save_format='h5')

# End-to-end example: PyTorch model -> Keras model (.h5) -> TFLite flatbuffer.
pytorch_model = PytorchNet()
# Keras is channels-last: (height, width, channels).
keras_model = KerasNet(input_shape=(224, 224, 3))

torch.save(pytorch_model, 'test.pth')

# Load the pretrained model.
# NOTE(review): torch.load unpickles the file, so only load checkpoints you
# trust; recent PyTorch releases may need weights_only=False to load a whole
# pickled module like this -- confirm against the installed version.
pytorch_model = torch.load('test.pth')

# Transfer the weights. PyTorch is channels-first: (channels, height, width).
converter = PytorchToKeras(pytorch_model, keras_model)
converter.convert((3, 224, 224))

# Save the converted Keras model for later use.
# Alternative (weights only): converter.save_weights("keras.h5")
converter.save_model("keras_model.h5")

# Convert the saved Keras model to a TFLite model.
converter = tf.contrib.lite.TocoConverter.from_keras_model_file(
    "keras_model.h5")
tflite_model = converter.convert()
# Context manager guarantees the file is flushed and closed.
with open("convert_model.tflite", "wb") as tflite_file:
    tflite_file.write(tflite_model)

補充知識:tensorflow模型轉換成tensorflow lite模型

1.把graph和網路模型打包在一個檔案中

 # Build the freeze_graph tool, then run it on eval_graph_def.pb with the
 # given checkpoint; the result is written to frozen_eval_graph.pb.
 bazel build tensorflow/python/tools:freeze_graph && \
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=eval_graph_def.pb \
 --input_checkpoint=checkpoint \
 --output_graph=frozen_eval_graph.pb \
 --output_node_names=outputs

For example:

 # Same freeze_graph invocation applied to the MobileNet v1 1.0/224 files.
 # No whitespace may follow a trailing backslash, or the continuation breaks.
 bazel-bin/tensorflow/python/tools/freeze_graph \
 --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt \
 --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt \
 --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_node_names=MobilenetV1/Predictions/Reshape_1

2.把第一步中生成的tensorflow pb模型轉換為tf lite模型

轉換前需要先編譯轉換工具

# Build the toco converter binary used by the commands below.
bazel build tensorflow/contrib/lite/toco:toco

轉換分兩種,一種是轉換為float的tf lite模型,另一種是轉換為對模型進行uint8量化後的版本。兩種方式如下:

非量化的轉換:

# NOTE: the path given on the official site,
#   ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco
# is wrong; use the path below instead.
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \
 --output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=FLOAT \
 --input_shape="1,224,224,3" \
 --input_array=input \
 --output_array=MobilenetV1/Predictions/Reshape_1

量化方式的轉換(注意,只有量化訓練的模型才能進行量化的tf_lite轉換):

# NOTE: the official instructions invoke
#   ./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco
# which is the wrong path; use the path below instead.
./bazel-bin/tensorflow/contrib/lite/toco/toco \
 --input_file=frozen_eval_graph.pb \
 --output_file=tflite_model.tflite \
 --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
 --inference_type=QUANTIZED_UINT8 \
 --input_shape="1,224,224,3" \
 --input_array=input \
 --output_array=outputs \
 --std_value=127.5 --mean_value=127.5

以上這篇Pytorch轉tflite方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。