TensorFlow學習筆記(5) TensorFlow模型持久化
阿新 • • 發佈:2018-12-12
TF提供了一個簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。
下面即為儲存TensorFlow計算圖的方法(saver.save()):
這樣就實現了持久化一個簡單的TF模型的功能,通過saver.save函式將TF模型儲存到model.ckpt中。雖然只指定了一個檔案路徑,但是在該路徑下會出現三個檔案:import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1')) v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2')) result = v1 + v2 #宣告tf.train.Saver類用於儲存模型 saver = tf.train.Saver() with tf.Session() as sess: sess.run(tf.initialize_all_variables()) #將模型儲存到model/model.ckpt檔案 saver.save(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt')
model.ckpt.meta | 儲存了TF計算圖的結構 |
model.ckpt | 儲存了TF程式中每個變數的取值 |
checkpoint | 儲存了一個目錄下所有的模型檔案列表 |
import tensorflow as tf v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1')) v2 = tf.Variable(tf.constant(1.0, shape=[1], name='v2')) result = v1 + v2 #宣告tf.train.Saver saver = tf.train.Saver() with tf.Session() as sess: #載入已經儲存的模型,並通過已經儲存的模型中的變數來計算 saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt') print(sess.run(result))
這兩段程式碼幾乎一樣,不同的地方在於載入模型的程式碼沒有初始化變數,而是通過已經儲存的模型載入進來。如果不想重複定義模型的結構,也可以直接將模型的結構加載出來:
為了儲存和載入部分變數,在宣告tf.train.Saver類時可以提供一個列表來指定儲存或者載入的變數,如 tf.train.Saver([v1]),這時就只有變數v1會被載入進來。除了可以指定被載入的變數,tf.train.Saver類也支援在儲存或載入時給變數重新命名:import tensorflow as tf #載入模型結構 saver = tf.train.import_meta_graph('E:\PycharmProjects\\tensorflow_learn/model/model.ckpt.meta') with tf.Session() as sess: saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model/model.ckpt') #通過張量的名稱來獲取張量 print(sess.run(tf.get_default_graph().get_tensor_by_name('add:0'))) #[2.]
import tensorflow as tf
v1 = tf.Variable(tf.constant(2.0, shape=[1], name='other-v1'))
v2 = tf.Variable(tf.constant(3.0, shape=[1], name='other-v2'))
#這裡如果直接使用tf.train.Saver()來載入模型會報錯
#使用一個字典來重新命名變數就可以載入原來的模型了
#這個字典指定了原名稱為v1的變數現在加到v1變數中
saver = tf.train.Saver({'v1': v1, 'v2': v2})
這種方式方便使用變數的滑動平均值,在載入模型時將影子變數對映到變數自身,那麼在訓練好的模型中就不需要再呼叫函式獲得變臉的滑動平均了:
import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
for variables in tf.global_variables():
print(variables.name)
#v:0
ema = tf.train.ExponentialMovingAverage(0.99)
maintain_averages_op = ema.apply(tf.global_variables())
for variables in tf.global_variables():
print(variables.name)
#v:0
#v/ExponentialMovingAverage:0
saver = tf.train.Saver()
with tf.Session() as sess:
sess.run(tf.initialize_all_variables())
sess.run(tf.assign(v, 10))
sess.run(maintain_averages_op)
saver.save(sess, 'E:\PycharmProjects\\tensorflow_learn/model1/model.ckpt')
print(sess.run([v, ema.average(v)]))
#[10.0, 0.099999905]
import tensorflow as tf
v = tf.Variable(0, dtype=tf.float32, name='v')
saver = tf.train.Saver({'v/ExponentialMovingAverage': v})
with tf.Session() as sess:
saver.restore(sess, 'E:\PycharmProjects\\tensorflow_learn/model1/model.ckpt')
print(sess.run(v))
#0.099999905
為了方便載入時重新命名滑動變數,tf.train.ExponentialMovingAverage類提供了variables_to_restore來生成重新命名所需要的字典,即{'v/ExponentialMovingAverage': v},因此上面的程式碼也可以改為
saver = tf.train.Saver(ema.variables_to_restore())
在TF中,提供了convert_variables_to_constant函式將計算圖中的變數及其取值通過常量的方式儲存,這樣整個TF計算圖可以統一存放在一個檔案中:
import tensorflow as tf
from tensorflow.python.framework import graph_util
v1 = tf.Variable(tf.constant(1.0, shape=[1], name='v1'))
v2 = tf.Variable(tf.constant(2.0, shape=[1], name='v2'))
result = v1 + v2
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
#匯出當前計算圖GraphDef部分,只需這一層就可以完成從輸入層到輸出層的計算
graph_def = tf.get_default_graph().as_graph_def()
#將匯出的計算圖中的變數轉化為常量
output_graph_def = graph_util.convert_variables_to_constants(sess, graph_def, ['add'])
with tf.gfile.GFile('model/combined_model.pb', 'wb') as f:
f.write(output_graph_def.SerializeToString())
在儲存之後,當只需要得到計算圖中某個節點的取值時,就會有一個更方便的方法:
import tensorflow as tf
from tensorflow.python.framework import graph_util
with tf.Session() as sess:
model_filename = 'model/combined_model.pb'
#讀取檔案並解析成對應的GraphDef Protocol Buffer
with gfile.FastGfile(model_filename, 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
#將儲存的圖載入到當前圖中,並給定返回張量的名稱
result = tf.import_graph_def(graph_def, return_elements=['add:0'])#add為張量的名稱
print(sess.run(result))
源自:Tensorflow 實戰Google深度學習框架_鄭澤宇