# 儲存和載入pb模型 (Saving and loading a TensorFlow .pb model)
# 阿新 • 發佈:2019-02-16
# 將模型儲存為pb (Save the model as a .pb file)
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Build a tiny graph with two variables, freeze their current values into
# constant nodes, and write the result as ONE binary GraphDef (.pb) file.
logdir = 'output/'
graph_path = logdir + 'expert_graph.pb'

# Build in a fresh graph so that re-running this section (e.g. in a notebook,
# or in the same process as the loading code) does not fail with
# "Variable conv/w already exists" in the default graph.
with tf.Graph().as_default() as graph:
    with tf.variable_scope('conv'):
        # Pass initializer *instances*: handing get_variable the class itself
        # relies on it instantiating the class, which is TF-version dependent.
        w = tf.get_variable('w', [2, 2], tf.float32,
                            initializer=tf.random_normal_initializer())
        b = tf.get_variable('b', [2], tf.float32,
                            initializer=tf.random_normal_initializer())

    # The context manager closes the session even if freezing raises.
    with tf.Session(graph=graph) as sess:
        tf.global_variables_initializer().run()  # initialize all variables
        # Freeze both variables in a single call so the output is one
        # well-formed GraphDef. Serializing two GraphDefs back to back into
        # the same file only "worked" because protobuf merges concatenated
        # messages, and would duplicate any node shared by both subgraphs.
        constant_graph = graph_util.convert_variables_to_constants(
            sess, sess.graph_def, ['conv/w', 'conv/b'])

# The output directory is not guaranteed to exist yet.
tf.gfile.MakeDirs(logdir)
with tf.gfile.GFile(graph_path, mode='wb') as f:
    f.write(constant_graph.SerializeToString())
# 載入pb模型 (Load the .pb model)
import tensorflow as tf
from tensorflow.python.framework import graph_util
# Read the frozen GraphDef written to output/expert_graph.pb and print the
# values stored for conv/w and conv/b.
logdir = 'output/'
output_graph_path = logdir + 'expert_graph.pb'

output_graph_def = tf.GraphDef()
# GFile matches the writer and also handles non-local filesystems.
with tf.gfile.GFile(output_graph_path, 'rb') as f:
    output_graph_def.ParseFromString(f.read())

# Import into a fresh graph. If this code runs in the same process as the
# saving code, the default graph may already hold the original conv/w and
# conv/b variables. 'conv/w:0' could then resolve to that variable, or the
# import could clash on names, instead of reading the frozen value.
with tf.Graph().as_default() as graph:
    _ = tf.import_graph_def(output_graph_def, name="")
    w = graph.get_tensor_by_name("conv/w:0")
    b = graph.get_tensor_by_name("conv/b:0")
    # The file was produced by convert_variables_to_constants, so these are
    # constants: no variable initializer is needed before evaluating them.
    with tf.Session(graph=graph) as sess:
        print('w:', w.eval())
        print('b:', b.eval())