1. 程式人生 > >Tensorflow(2)儲存模型與恢復

Tensorflow(2)儲存模型與恢復

###一、資料模型的儲存
使用saver類,自動儲存tensorflow的圖結構(***.ckpt.meta),引數取值(***.ckpt.data),以及目錄下的檔案列表(***.ckpt.index),還有一個checkpoint檔案。

  • 定義變數
  • 變數操作
  • 變數初始化
  • 構建saver類
  • 使用儲存模型引數到檔案
import tensorflow as tf

v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(4.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(4.0,shape=[1]),name='v3')

result1=v1+v2
result2=result1+v3

init_op=tf.global_variables_initializer()
saver=tf.train.Saver()

with tf.Session() as sess:
	sess.run(init_op)
	saver.save(sess,"codes/tensorflow_test/model/model.ckpt")

###二、引數恢復
引數恢復程式中必須已經定義了引數,並且引數名稱要和定義的引數名字一致。可以使用tf.Variables(),tf.global_variables()獲取引數名字(適用於修改別人程式的時候)。
note:引數名稱必須一致(name=“ ”),具體的操作(result,result2)可以修改

for variables in tf.global_variables():
	print(variables.name,variables.shape)

程式

import tensorflow as tf

v1=tf.Variable(tf.constant(1.0,shape=[1]),name='v1')
v2=tf.Variable(tf.constant(1.0,shape=[1]),name='v2')
v3=tf.Variable(tf.constant(1.0,shape=[1]),name='v3')

result1=v1+v2
result2=v1+v2+v3

saver=tf.train.Saver()
#saver=tf.train.Saver([v1,v2] )#部分引數恢復,這個時候需要註釋v3,或者給v3加上初始化操作。
with tf.Session() as sess:
	saver.restore(sess,"codes/tensorflow_test/model/model.ckpt")
	# for variables in tf.global_variables():
	# 	print(variables.name)
	print(sess.run(result1))
	print(sess.run(result2))
	print("sucessful\n")

###三、計算圖與引數同時恢復

  • 計算圖和引數同時恢復的時候,不需要定義變數,也不需要變數初始化
  • 變數在匯入圖結構之後就已經獲取了
  • 可以使用原來的圖結構中的操作,這個時候需要指定運算名稱
  • 也可以自定義新的操作
import tensorflow as tf

saver=tf.train.import_meta_graph("/home/wuwei/codes/tensorflow_test/model/model.ckpt.meta")
# for variables in tf.global_variables():  ###get name and shape
#     print(variables.name, variables.shape)
result3=tf.get_default_graph().get_tensor_by_name("v1:0")+tf.get_default_graph().get_tensor_by_name("v2:0")
with tf.Session() as sess:
    saver.restore(sess,"/home/wuwei/codes/tensorflow_test/model/model.ckpt")
    print(sess.run(result3))  ### our op
    print("**********************")
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add_1:0")))###original op
    print("sucessful\n")

###四、另一種儲存tensorflow模型的操作,整個圖和引數設定為常量,儲存在一個檔案中
在使用tensorRT進行推理的時候,需要使用到這種模型。

##儲存模型save_model.py

import tensorflow as tf
from tensorflow.python.framework import graph_util


v1=tf.Variable(tf.constant(5.0,shape=[1],name="v1"))
v2=tf.Variable(tf.constant(4.0,shape=[1],name="v2"))
v3=tf.Variable(tf.constant(3.0,shape=[1],name="v3"))

result=v1+v2
print(result.name)
init_op=tf.global_variables_initializer()

with tf.Session() as sess:
    sess.run(init_op)
    graph_def=tf.get_default_graph().as_graph_def()

    output_graph_def=graph_util.convert_variables_to_constants(sess,graph_def,['add'])
    with tf.gfile.GFile("/home/wuwei/codes/tensorflow_test/model/combined_model.pb",'wb') as f:
        f.write(output_graph_def.SerializeToString()) 

##restore3.py

import tensorflow as tf
from tensorflow.python.platform import gfile
with tf.Session() as sess:
    model_filename="/home/wuwei/codes/tensorflow_test/model/combined_model.pb"
    with gfile.FastGFile(model_filename,'rb') as f:
        graph_def=tf.GraphDef()
        graph_def.ParseFromString(f.read())
    result=tf.import_graph_def(graph_def,return_elements=["add:0"])
    print(sess.run(result))