java呼叫tensorflow訓練好的模型
阿新 • 發佈：2018-12-19
1. Python 的處理
整個模型的原始碼在此:https://github.com/shelleyHLX/tensorflow_java
多謝star
首先用 Python 訓練一個簡單的線性迴歸模型（資料為 y ≈ 2x 加上雜訊），並把它匯出成凍結圖 model/first.pb，程式碼如下：
"""Train a 1-D linear regression model (y ~ W*x + b) with TensorFlow 1.x.

The trained model is saved twice:
  * as a checkpoint at model/first (tf.train.Saver), and
  * as a frozen GraphDef at model/first.pb, whose output node is named
    "results" and whose input placeholder is named "X". This is the file
    the Java program loads.
"""
import os

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.python.framework import graph_util

# 100 evenly spaced points in [-1, 1] on the line y = 2x, plus Gaussian noise
# scaled by 0.1.
train_X = np.linspace(-1, 1, 100)
train_Y = 2 * train_X + np.random.randn(*train_X.shape) * 0.1

# Show the simulated data points.
plt.plot(train_X, train_Y, 'ro', label='test')
plt.legend()
plt.show()

# ---- Model ---------------------------------------------------------------
# Placeholders. Their names ("X", "Y") are how other programs look them up.
X = tf.placeholder("float", name='X')
Y = tf.placeholder("float", name='Y')
# W starts as one sample from tf.random_normal, with shape [1].
W = tf.Variable(tf.random_normal([1]), name="weight")
# b starts at 0, also with shape [1].
b = tf.Variable(tf.zeros([1]), name="bias")

# Forward pass. The tensor is named "results" so the frozen graph can be
# queried by name. The loss below is built on this same tensor rather than on
# a second, duplicate copy of the computation.
z = tf.add(tf.multiply(X, W), b, name='results')

# Backward pass: mean squared error minimised by plain gradient descent.
cost = tf.reduce_mean(tf.square(Y - z))
learning_rate = 0.01
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

init = tf.global_variables_initializer()

training_epochs = 20
display_step = 2


def moving_avage(a, w=10):
    """Smooth ``a`` with a trailing window of size ``w``.

    The first ``w`` entries are returned unchanged. Each later entry at
    position ``idx`` is replaced by the mean of the ``w`` entries before it,
    ``a[idx-w:idx]``, which does not include the entry itself. A sequence
    shorter than ``w`` is returned as a shallow copy.
    """
    if len(a) < w:
        return a[:]
    return [val if idx < w else sum(a[(idx - w):idx]) / w
            for idx, val in enumerate(a)]


saver = tf.train.Saver()

# Saver.save and the .pb write below both fail if the directory is missing.
os.makedirs("model", exist_ok=True)

with tf.Session() as sess:
    sess.run(init)
    # Epoch indices and the full-dataset loss recorded at each of them.
    plotdata = {"batchsize": [], "loss": []}
    for epoch in range(training_epochs):
        # One optimisation step per individual sample.
        for (x, y) in zip(train_X, train_Y):
            sess.run(optimizer, {X: x, Y: y})

        # Periodically report progress on the whole training set.
        if epoch % display_step == 0:
            loss = sess.run(cost, {X: train_X, Y: train_Y})
            print("Epoch:", epoch + 1, "cost=", loss,
                  "W=", sess.run(W), "b=", sess.run(b))
            # Skip invalid (NaN) losses so they do not poison the plot.
            if not np.isnan(loss):
                plotdata["batchsize"].append(epoch)
                plotdata["loss"].append(loss)

    print("Finished!")
    saver.save(sess, "model/first")
    print("cost =", sess.run(cost, feed_dict={X: train_X, Y: train_Y}),
          "W=", sess.run(W), "b=", sess.run(b))

    # Freeze the graph: replace the variables needed by "results" with
    # constants holding their trained values, then serialize it.
    const_graph = graph_util.convert_variables_to_constants(
        sess, sess.graph_def, ["results"])
    with tf.gfile.GFile("model/first.pb", mode='wb') as f:
        f.write(const_graph.SerializeToString())

    # Plot the data together with the fitted line.
    plt.plot(train_X, train_Y, 'ro', label='Original data')
    plt.plot(train_X, sess.run(W) * train_X + sess.run(b),
             label='Fitted line')
    plt.legend()
    plt.show()

    # Plot the smoothed training loss.
    plotdata["avgloss"] = moving_avage(plotdata["loss"])
    plt.subplot(211)
    plt.plot(plotdata["batchsize"], plotdata["avgloss"], 'b--')
    plt.xlabel('Minibatch number')
    plt.ylabel('Loss')
    plt.title('Minibatch run vs. Training loss')
    plt.show()

    print("x=0.2, z=", sess.run(z, {X: 0.2}))
在 Python 中載入匯出的 model/first.pb 測試模型：
"""Load the frozen graph model/first.pb and evaluate it for X = 2."""
import tensorflow as tf
from tensorflow.python.platform import gfile

# Parse the serialized GraphDef written by the training script.
with gfile.GFile('model/first.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())

# Import into a dedicated graph. Graph.as_default() returns a context
# manager, so it only takes effect inside a `with` statement.
graph = tf.Graph()
with graph.as_default():
    tf.import_graph_def(graph_def, name='')

# The training script converted the variables to constants before saving, so
# there is nothing to initialise. The `with` closes the session when done.
with tf.Session(graph=graph) as sess:
    print(sess.run('weight:0'))
    print(sess.run('bias:0'))
    input_x = graph.get_tensor_by_name('X:0')
    op = graph.get_tensor_by_name('results:0')
    ret = sess.run(op, feed_dict={input_x: 2})
    print(ret)
2. Java 的處理
新建一個maven專案
把 Python 匯出的 first.pb 複製到專案中：Java 程式以相對路徑 model/first.pb 讀取，所以要把它放在執行時工作目錄（通常是專案根目錄）下的 model 資料夾裡。
在 pom.xml 中加入 TensorFlow Java 的依賴（org.tensorflow:tensorflow），第一次建置時 Maven 會自動下載。
在 xin/src/test/java 下的套件 com.xin.tf_java.xin（即目錄 com/xin/tf_java/xin）中新建一個 Java 類別：abcd.java
內容如下:
package com.xin.tf_java.xin; import java.awt.image.BufferedImage; import java.io.File; import java.io.FileInputStream; import java.io.FileNotFoundException; import java.io.IOException; import java.nio.charset.Charset; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; import java.util.ArrayList; import java.util.Comparator; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.PriorityQueue; import org.apache.commons.io.IOUtils; import javax.imageio.ImageIO; import org.tensorflow.Graph; import org.tensorflow.Operation; import org.tensorflow.Output; import org.tensorflow.Session; import org.tensorflow.Shape; import org.tensorflow.Tensor; import org.apache.commons.io.IOUtils; public class abcd { public static void main(String[] args) throws FileNotFoundException, IOException { // TODO Auto-generated method stub try (Graph graph = new Graph()) { byte[] graphBytes = IOUtils.toByteArray(new FileInputStream("model/first.pb")); graph.importGraphDef(graphBytes); try (Session session = new Session(graph)) { Tensor<?> out = session.runner().feed("X", Tensor.create(2.0f)).fetch("results").run().get(0); float[] r = new float[1]; out.copyTo(r); System.out.println(r[0]); } } } }
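程式用到 Apache Commons IO 的 IOUtils，因此要把 commons-io-2.6.jar 加入建置路徑（既然是 Maven 專案，也可以直接在 pom.xml 加入 commons-io:commons-io:2.6 依賴）；下載位置：http://commons.apache.org/proper/commons-io/download_io.cgi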
若 Eclipse 提示「change project compliance and JRE to 1.7」，照做即可（程式使用的 try-with-resources 語法需要 Java 7 以上）。
在 abcd.java 上按右鍵執行（Run As → Java Application）。
reference: