Tensorflow(2)儲存模型與恢復
阿新 • • 發佈:2018-11-21
###一、資料模型的儲存
使用saver類,自動儲存tensorflow的圖結構(***.ckpt.meta),引數取值(***.ckpt.data),以及目錄下的檔案列表(***.ckpt.index),還有一個checkpoint檔案。
- 定義變數
- 變數操作
- 變數初始化
- 構建saver類
- 使用儲存模型引數到檔案
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(4.0,shape=[1]),name='v2') v3=tf.Variable(tf.constant(4.0,shape=[1]),name='v3') result1=v1+v2 result2=result1+v3 init_op=tf.global_variables_initializer() saver=tf.train.Saver() with tf.Session() as sess: sess.run(init_op) saver.save(sess,"codes/tensorflow_test/model/model.ckpt")
###二、引數恢復
引數恢復程式中必須已經定義了引數,並且引數名稱要和定義的引數名字一致。可以使用tf.Variables(),tf.global_variables()獲取引數名字(適用於修改別人程式的時候)。
note:引數名稱必須一致(name=“ ”),具體的操作(result,result2)可以修改
for variables in tf.global_variables():
print(variables.name,variables.shape)
程式
import tensorflow as tf v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1') v2=tf.Variable(tf.constant(1.0,shape=[1]),name='v2') v3=tf.Variable(tf.constant(1.0,shape=[1]),name='v3') result1=v1+v2 result2=v1+v2+v3 saver=tf.train.Saver() #saver=tf.train.Saver([v1,v2] )#部分引數恢復,這個時候需要註釋v3,或者給v3加上初始化操作。 with tf.Session() as sess: saver.restore(sess,"codes/tensorflow_test/model/model.ckpt") # for variables in tf.global_variables(): # print(variables.name) print(sess.run(result1)) print(sess.run(result2)) print("sucessful\n")
###三、計算圖與引數同時恢復
- 計算圖和引數同時恢復的時候,不需要定義變數,也不需要變數初始化
- 變數在匯入圖結構之後就已經獲取了
- 可以使用原來的圖結構中的操作,這個時候需要指定運算名稱
- 也可以自定義新的操作
import tensorflow as tf saver=tf.train.import_meta_graph("/home/wuwei/codes/tensorflow_test/model/model.ckpt.meta") # for variables in tf.global_variables(): ###get name and shape # print(variables.name, variables.shape) result3=tf.get_default_graph().get_tensor_by_name("v1:0")+tf.get_default_graph().get_tensor_by_name("v2:0") with tf.Session() as sess: saver.restore(sess,"/home/wuwei/codes/tensorflow_test/model/model.ckpt") print(sess.run(result3)) ### our op print("**********************") print(sess.run(tf.get_default_graph().get_tensor_by_name("add_1:0")))###original op print("sucessful\n")
###四、另一種儲存tensorflow模型的操作,整個圖和引數設定為常量,儲存在一個檔案中
在使用tensorRT進行推理的時候,需要使用到這種模型。
##儲存模型save_model.py
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1=tf.Variable(tf.constant(5.0,shape=[1],name="v1"))
v2=tf.Variable(tf.constant(4.0,shape=[1],name="v2"))
v3=tf.Variable(tf.constant(3.0,shape=[1],name="v3"))
result=v1+v2
print(result.name)
init_op=tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
graph_def=tf.get_default_graph().as_graph_def()
output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
with tf.gfile.GFile("/home/wuwei/codes/tensorflow_test/model/combined_model.pb",'wb') as f:
f.write(output_graph_def.SerializeToString())
##restore3.py
import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
model_filename="/home/wuwei/codes/tensorflow_test/model/combined_model.pb"
with gfile.FastGFile(model_filename,'rb') as f:
graph_def=tf.GraphDef()
graph_def.ParseFromString(f.read())
result=tf.import_graph_def(graph_def,return_elements=["add:0"])
print(sess.run(result))