1. 程式人生 > >onnx與tensorflow格式的相互轉換

onnx與tensorflow格式的相互轉換

onnx是Facebook打造的AI中介軟體,但是Tensorflow官方不支援onnx,所以只能用onnx自己提供的方式從tensorflow嘗試轉換

Tensorflow模型轉onnx

Tensorflow轉onnx, onnx官方github上有提供轉換的方式,地址為https://github.com/onnx/tutorials/blob/master/tutorials/OnnxTensorflowExport.ipynb 。按連結中的步驟一步一步就能完成mnist的模型轉換,我也成功轉換出了mnist.onnx模型。但是在上面步驟中model = onnx.load('mnist.onnx')

之後執行tf_rep = prepare(model)一直不成功。但是換成網上別人用pytorch轉的mnist.onnx執行tf_rep = prepare(model)又完全是OK的,這個暫時還沒找到原因在哪裡。

onnx模型轉換為Tensorflow模型

上面提到按官網的教程從tensorflow轉換生成的onnx模型執行tf_rep = prepare(model)有問題。所以這裡我從網上下載的一個pytorch轉換的mnist onnx模型為實驗物件,實驗用的onnx下載地址:https://download.csdn.net/download/computerme/10448754
onnx模型轉換為Tensorflow模型的程式碼如下:

import onnx
import numpy as np
from onnx_tf.backend import prepare

model = onnx.load('./assets/mnist_model.onnx')
tf_rep = prepare(model)

img = np.load("./assets/image.npz")
output = tf_rep.run(img.reshape([1, 1,28,28]))

print("outpu mat: \n",output)
print("The digit is classified as ", np.argmax(output))

import
tensorflow as tf with tf.Session() as persisted_sess: print("load graph") persisted_sess.graph.as_default() tf.import_graph_def(tf_rep.predict_net.graph.as_graph_def(), name='') inp = persisted_sess.graph.get_tensor_by_name( tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_input[0]].name ) out = persisted_sess.graph.get_tensor_by_name( tf_rep.predict_net.tensor_dict[tf_rep.predict_net.external_output[0]].name ) res = persisted_sess.run(out, {inp: img.reshape([1, 1,28,28])}) print(res) print("The digit is classified as ",np.argmax(res)) tf_rep.export_graph('tf.pb')

轉換完成後,需要對轉換出的tf.pb模型進行驗證,驗證方式如下:

import numpy as np
import tensorflow as tf
from tensorflow.python.platform import gfile

name = "tf.pb"

with tf.Session() as persisted_sess:
    print("load graph")
    with gfile.FastGFile(name, 'rb') as f:
        graph_def = tf.GraphDef()
        graph_def.ParseFromString(f.read())

    persisted_sess.graph.as_default()
    tf.import_graph_def(graph_def, name='')

    inp = persisted_sess.graph.get_tensor_by_name('0:0')
    out = persisted_sess.graph.get_tensor_by_name('LogSoftmax:0')
    #test = np.random.rand(1, 1, 28, 28).astype(np.float32)
    #feed_dict = {inp: test}

    img = np.load("./assets/image.npz")
    feed_dict = {inp: img.reshape([1, 1,28,28])}

    classification = persisted_sess.run(out, feed_dict)
    print(out)
    print(classification)