TensorFlow Low-Level-APIs Save and Restore學習筆記
Save and Restore
本文翻譯自tensorflow官方網站的教程,只作為個人學習筆記,請勿用作商業用途。 tf.train.Saver類提供了儲存和提取模型的方法。tf.saved_model.simple_save函式也是一種簡單的方法來儲存模型。高階APIEstimaztors會自動的儲存和提取在model_dir中的模型。
Save and restore variables
TensorFlow Variables are the best way to represent shared, persistent state manipulated by your program.tf.train.Saver構造器為計算圖中的所有節點或者是指定的節點序列新增save和resotre的ops。然後Saver類去執行這些ops,並指定儲存和恢復的路徑來實現讀寫。
Saver類能從儲存的模型中讀取出計算圖中定義的Variable。如果你載入了一個模型但是不知道怎麼利用它來構建計算圖,可以參看Overview of saving and restoring models。TensorFlow通過二進位制的檔案來儲存Variable,並在其中做了Variable的name和值的對映。
Save variables
使用tf.train.Saver()來構建一個Saver以便控制整個模型,下例是一個使用Saver來儲存Variable的例子:
# Create some variables
v1 = tf.get_variable("v1", shape=[ 3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
inc_v1 = v1.assign(v1+1)
dec_v2 = v2.assign(v2-1)
# Add an op to initialize the variables
init_op = tf.global_variables_initializer()
# Add ops to save and restore all the variables
saver = tf.train.Saver()
# Later, launh the model, initialize the variables, do some work, and save variables to disk
with tf.Session() as sess:
sess.run(init_op)
# Do some work with the model
inc_v1.op.run()
dec_v2.op.run()
save_path = saver.save(sess, "/temp/model.ckpt")
print("Model saved in path: %s" % save_path)
Restore variables
tf.train.Saver物件不僅儲存variable,還能將variable從檔案中讀取variable。注意當從檔案中讀取variable之前你不應該初始化他們,下面是一個例子:
tf.rest_default_graph()
# Create some variables
v1 = tf.get_variable("v1", shape=[3])
v2 = tf.get_variable("v2", shape=[5])
# Add ops to save and restore all the variables
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and do some work woth the model
with tf.Session() as sess:
# Restore variables from disk
saver.restore(sess, "/tmp/model.ckpt")
print("Model resotred")
# Check the values of the variables
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
Choose variables to save and restore
如果在構造tf.train.Saver()的時候沒有傳入任何引數,saver將會管理整個計算圖中的variable,每個variable儲存的時候使用的name就是構造variable時候的name。有時候給儲存的variable重新賦予name是有用的,比如在之前的計算圖中某個variable的name是weights,而你希望儲存的name是params。有時候也許只需要儲存計算圖中的一部分variable。可以通過向tf.train.Saver()傳入如下引數來實現這樣的目的:
- 一個variable的list(他們儲存的name就是構造他們時的name)
- 一個字典dict,key是儲存時使用的name,value是需要儲存的variable
使用之前的例子:
tf.rest_default_graph()
# Create some variables
v1 = tf.get_variable("v1", shape=[3], initializer=tf.zeros_initializer)
v2 = tf.get_variable("v2", shape=[5], initializer=tf.zeros_initializer)
# Add ops to save and restore only 'v2' using the name "v2"
saver = tf.train.Saver({"v2":v2})
# Use the saver object normally after that
with tf.Session() as sess:
# Initialize v1 since the saver will not
v1.initializer.run()
saver.restore(sess, "/tmp/model.ckpt")
print("v1 : %s" % v1.eval())
print("v2 : %s" % v2.eval())
注意事項:
- 如果需要對計算圖中不同的部分進行分別儲存可以構建多個Saver物件,同一個variable也可以通過多個Saver來儲存。
- 如果只想在session開始的時候讀取一部分的variable,除了這些讀取的variable其他的variable需要被初始化。
- 審查checkpoint中的variable,需要使用inspect_checkpoint庫,print_tensor_in_checkpoint_file函式
- 預設情況下Saver使用tf.Variable.name屬性來儲存每個variable,在建立Saver物件的時候也可以對checkpoint檔案中每個variable賦予name。
Inspect variables in a checkpoint
使用inspect_checkpoint庫來審查variable,仍然使用之前的例子:
# import the inspect_checkpoint library
from tensorflow.python.tools import inspect_checkpoint as chkp
# print all tensor in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='', all_tensor=True)
# tensor_name: v1
# [ 1. 1. 1.]
# tensor_name: v2
# [ -1. -1. -1. -1. -1.]
# print only tensor v1 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v1', all_tensor=False)
# tensor_name: v1
# [ 1. 1. 1.]
# print only tensor v2 in checkpoint file
chkp.print_tensors_in_checkpoint_file("/tmp/model.ckpt", tensor_name='v2', all_tensor=False)
# tensor_name: v2
# [ -1. -1. -1. -1. -1.]
Save and restore models
使用SaveModel可以儲存和載入模型包括variable、graph和graph的metadata。這種方法儲存模型更加自然,可以通過多個介面和SaveModel進行互動比如tf.saved_model或者高階APIs。
Build and load a SaveModel
Simple save
建立SaveModel最簡單的方法就是使用tf.save_model.simple_save函式:
simple_save(session,
export_dir,
inputs={"x":x, "y":y},
outputs={"z":z})
構造SaveModel的引數定義了輸入和輸出使得它可以被伺服器直接拿來進行使用,來預測或者是進行訓練。
Manually build a SaveModel
如果你的使用環境沒辦法使用tf.saved_model.simple_save,可以使用人工的builder APIs來構造一個SaveModel。tf.saved_model.builder.SaveMOdelBuilder類提供了儲存多個MetaGraphDef的功能。一個MetaGraph是一個數據流圖,再增加和它相關的variable、assets和signatures。一個MetaGraphDef是關於MetaGraph的一個protocol buffers(序列化標準)的表示。一個signature就是計算圖的輸入和輸出。如果assets需要被儲存或者寫或者複製到磁碟,那麼可以在熟悉新增MetaGraphDef時提供這些資源。如果多個MetaGraphDef與同名資源相關聯,則只保留首個版本。每個新增到SaveModel中的MetaGraphDef都必須使用使用者定義的tag進行區分。通過指定tag來確定要restore哪個MetaGraphDef,tag一般是表示用於訓練或者是用於推斷,或者指明執行在什麼裝置上如GPU。
export_dir = ...
...
builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph_and variables(sess,
[tag_constants.TRAINING],
signature_def_map=foo_signatures,
assets_collection=foo_assets,
strip_default_attrs=True)
...
# Add a second MetaGraphDef for inference
with tf.Session(graph=tf.Graph()) as sess:
...
builder.add_meta_graph([tag_constants.SERVING], strip_default_attrs=True)
...
builder.save()
Forward compatibility via strip_deefault_attrs=True
如果計算圖中的ops沒有改變,那麼以下的教程讓你的模型restore以後具有前向傳播的能力。SaveModelBuilder類允許使用者來控制是否刪去ModeDefs中的一些default的屬性。SavedModelBuilder.add_meta_graph_and_variables和SavedModelBuilder.add_meta_graph方法都可以接受一個布林變數strip_default_attrs來控制是否刪掉。
如果strip_default_attrs是False,那麼輸出的tf.MetaGraphDef將要保留它其中的所有tf.NodeDef例項中的default的屬性。如下的一些情況也會導致這種前向傳播的能力失去,具體參考compatibility guidance
Loading a SavedModel in Python
Python版本的SavedModel的loader提供了儲存和載入模型的能力,load操作需要一下的資訊:
- 載入計算圖和variable對應的Session
- 定位MetaGraphDef的tag
- SavedModel的儲存路徑
在一次載入中variable,assets和signature都作為MetaGraphDef的一部分一起載入到了提供的session中。
export_dir = ...
...
with tf.Session(graph=tf.Graph()) as sess:
tf.saved_model.loader.load(sess, [tag_constants.TRAINING], export_dir)