TensorFlow (4): caffe-tensorflow study notes
阿新 • Published: 2018-12-30
Following the LeNet example, convert the model and the network:
LeNet Example
Thanks to @Russell91 for this example
This example shows you how to finetune code from the Caffe MNIST tutorial using TensorFlow.
First, you can convert a prototxt model to tensorflow code:
$ ./convert.py examples/mnist/lenet.prototxt --code-output-path=mynet.py
This produces TensorFlow code for the LeNet network in mynet.py. The code can be imported as described below in the Inference section. Caffe-tensorflow also lets you convert .caffemodel weight files to .npy files that can be loaded directly from TensorFlow:
$ ./convert.py examples/mnist/lenet.prototxt --caffemodel examples/mnist/lenet_iter_10000.caffemodel --data-output-path=mynet.npy
The above command will generate a weight file named mynet.npy.
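To check what a converted weight file contains before loading it into a network, it can be opened with NumPy alone. The following is a minimal sketch, assuming the .npy file is a pickled dictionary that maps layer names to named parameter arrays; that layout is an assumption and is not confirmed by the text above. The file name fake_mynet.npy, the array shapes, and the helper summarize_weights are hypothetical and exist only for illustration.

```python
import os
import tempfile

import numpy as np

# Hypothetical stand-in for a converted weight file.
# ASSUMPTION: the layout {layer_name: {param_name: array}} and these shapes
# are illustrative only; inspect your own file to see its real structure.
fake_weights = {
    'conv1': {'weights': np.zeros((5, 5, 1, 20), dtype=np.float32),
              'biases': np.zeros((20,), dtype=np.float32)},
    'ip2': {'weights': np.zeros((500, 10), dtype=np.float32),
            'biases': np.zeros((10,), dtype=np.float32)},
}


def summarize_weights(path):
    """Return {layer: {param: shape}} for a pickled-dict .npy file."""
    # allow_pickle=True is required because the file stores a Python dict,
    # not a plain array; .item() unwraps the 0-d object array.
    data = np.load(path, allow_pickle=True).item()
    return {layer: {name: tuple(arr.shape) for name, arr in params.items()}
            for layer, params in data.items()}


# Write the stand-in file to a temporary directory, then read it back.
path = os.path.join(tempfile.mkdtemp(), 'fake_mynet.npy')
np.save(path, fake_weights)

summary = summarize_weights(path)
for layer in sorted(summary):
    print(layer, summary[layer])
```

Pointing summarize_weights at a real converted file, instead of the stand-in, should list the layers and parameter shapes it holds, provided the file follows the assumed layout.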
Inference:
Once you have generated both the code and weight files for LeNet, you can finetune LeNet using TensorFlow with
$ ./examples/mnist/finetune_mnist.py
At a high level, finetune_mnist.py works as follows:
import tensorflow as tf

# Import the converted model's class
from mynet import MyNet

# Create an instance, passing in the input data
net = MyNet({'data': my_input_data})

with tf.Session() as sesh:
    # Load the data
    net.load('mynet.npy', sesh)
    # Forward pass
    output = sesh.run(net.get_output(), ...)
Code for running the converted model (my own addition):
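import sys

import numpy as np
import tensorflow as tf
from PIL import Image

# Make caffe-tensorflow and the converted YOLO code importable
sys.path.append('/home/yang/caffe-tensorflow')
sys.path.append('/home/yang/caffe-tensorflow/examples/yolo')
import yolo

# Placeholder for a single 448x448 RGB image (batch, height, width, channels)
image = tf.placeholder(tf.float32, [1, 448, 448, 3])
net = yolo.yolo({'data': image})

# Load the test image and resize it to the network's input size
image_path = '/home/yang/darknet/data/dog.jpg'
im = Image.open(image_path)
im_resized = im.resize((448, 448))
input_data = np.array(im_resized)
input_data = input_data.reshape((1, 448, 448, 3))
# Shift and scale pixel values from [0, 255] to roughly [-1, 1]
input_data = (input_data * 1.0 - 127.5) * 0.007874015748031496

with tf.Session() as sess:
    # tf.global_variables_initializer() replaces the deprecated
    # tf.initialize_all_variables()
    sess.run(tf.global_variables_initializer())
    # Load the converted weights
    net.load('/home/yang/caffe-tensorflow/examples/yolo/yolo.npy', sess)
    # Forward pass
    output = sess.run(net.get_output(), feed_dict={image: input_data})

# Write the 1470 output values to a text file, separated by spaces
with open('/home/yang/Desktop/result.txt', 'w') as f:
    for i in range(1470):
        f.write(str(output[0][i]) + ' ')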
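The preprocessing line in the script above subtracts 127.5 and multiplies by 0.007874015748031496, which is 1/127, so pixel values in [0, 255] end up in roughly [-1.004, 1.004]. The sketch below isolates that step so it can be checked without TensorFlow or an image file. The function name preprocess and the dummy image are hypothetical, and casting to float32 is an assumption made here to match the placeholder's dtype.

```python
import numpy as np

# Constant taken from the script above; it equals 1/127.
SCALE = 0.007874015748031496


def preprocess(pixels):
    """Map uint8 HxWx3 pixels to a float32 batch of shape (1, H, W, 3)."""
    x = np.asarray(pixels, dtype=np.float32)
    x = (x - 127.5) * SCALE          # shift to zero mean, then scale by 1/127
    return x[np.newaxis, ...]        # add the leading batch dimension


# Hypothetical stand-in for a resized 448x448 RGB image:
# all black except one white pixel, so both extremes appear.
dummy = np.zeros((448, 448, 3), dtype=np.uint8)
dummy[0, 0, :] = 255

batch = preprocess(dummy)
print(batch.shape, batch.dtype, batch.min(), batch.max())
```

Because the divisor is 127 rather than 127.5, the extremes fall slightly outside [-1, 1]. The sketch keeps that behavior to stay consistent with the original script.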
Additional MNIST test code:
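import sys

# Add the local paths needed for the imports below
sys.path.append('/home/yang/caffe-tensorflow/examples/mnist')
sys.path.append('/home/yang/tensorflow')
sys.path.append('/home/yang/caffe-tensorflow')

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
from mynet import LeNet as MyNet

# Load MNIST with one-hot encoded labels
mnist = input_data.read_data_sets('/home/yang/data', one_hot=True)

# One flattened 28x28 image and its one-hot label
# (labels is defined but not used in the forward pass below)
image = tf.placeholder(tf.float32, [1, 784])
labels = tf.placeholder(tf.float32, [1, 10])
# Reshape the flat vector to (batch, height, width, channels)
net_input = tf.reshape(image, shape=[-1, 28, 28, 1])
net = MyNet({'data': net_input})

with tf.Session() as sess:
    # tf.global_variables_initializer() replaces the deprecated
    # tf.initialize_all_variables()
    sess.run(tf.global_variables_initializer())
    # Load the converted weights
    net.load('/home/yang/caffe-tensorflow/examples/mnist/mynet.npy', sess)
    # Run a forward pass on a single training example
    batch_xs, batch_ys = mnist.train.next_batch(1)
    output = sess.run(net.get_output(), feed_dict={image: batch_xs})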
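The test script computes output and fetches batch_ys but never compares the two. One way to check the prediction is to compare the argmax of each, as sketched below. This assumes the network output has shape (1, 10), with one score per digit class, which the script does not state explicitly. The helper predicted_matches_label and the sample arrays are hypothetical stand-ins for output and batch_ys.

```python
import numpy as np


def predicted_matches_label(output, one_hot_labels):
    """Return a boolean array: True where the top-scoring class is the label."""
    predicted = np.argmax(output, axis=1)          # index of the highest score
    expected = np.argmax(one_hot_labels, axis=1)   # index of the 1 in the label
    return predicted == expected


# Hypothetical values standing in for `output` and `batch_ys`
fake_output = np.array([[0.01, 0.02, 0.90, 0.01, 0.01,
                         0.01, 0.01, 0.01, 0.01, 0.01]])
fake_labels = np.zeros((1, 10))
fake_labels[0, 2] = 1.0

matches = predicted_matches_label(fake_output, fake_labels)
print(matches)
```

Taking the mean of the returned boolean array over a larger batch would give an accuracy estimate, but that would also require widening the placeholders in the script beyond a batch size of 1.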