Tensorflow模型儲存和過載
最近因為專案要求,需要把模型的訓練和測試過程分開,這裡主要涉及兩個過程:訓練圖的存取和引數的存取。
以下所有/home/yy/xiajbxie/model是我的模型的儲存路徑,將其換成你自己的即可。
tf.train.Saver()
Saver的作用中文社群已經講得相當清楚。tf.train.Saver()類的基本操作時save()和restore()函式,分別負責模型引數的儲存和恢復。引數儲存示例如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1" )
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
# Add an op to initialize the variables.
init_op = tf.initialize_all_variables()
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# initialize the variables, save the variables to disk.
with tf.Session() as sess:
sess.run(init_op)
v1, v2 = sess.run([v1, v2])
print(v1)
print(v2)
# Do some work with the model.
# Save the variables to disk.
save_path = saver.save(sess, "/home/yy/xiajbxie/model")
print "Model saved in file: ", save_path
執行結果:
[[-0.0493206 0.12752049]]
[[ 1.9456626 0.6319563 -0.1296857 ]
[-0.7834143 0.33656874 -0.96077037]]
Model saved in file: /home/yy/xiajbxie/model
引數恢復示例如下:
import tensorflow as tf
# Create some variables.
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
# Add ops to save and restore all the variables.
saver = tf.train.Saver()
# Later, launch the model, use the saver to restore variables from disk, and
# do some work with the model.
with tf.Session() as sess:
# Restore variables from disk.
saver.restore(sess, "/home/yy/xiajbxie/model")
print "Model restored."
print(sess.run([v1, v2]))
執行結果:
Model restored.
[array([[-0.0493206 , 0.12752049]], dtype=float32), array([[ 1.9456626 , 0.6319563 , -0.1296857 ],
[-0.7834143 , 0.33656874, -0.96077037]], dtype=float32)]
saver.save()
函式的引數為需儲存的會話,以及模型的儲存路徑。儲存後我們進入模型的儲存路徑會看到4個新增檔案,4個檔案根據tensorflow版本不同名字不同,以上例為例,1.2版本4個檔案如下:
1. checkpoint
:其中儲存模型所在的路徑
2. model.meta
:包含計算圖的完整資訊
3. model.index
:與下面的檔案一起儲存所有的變數值
4. model.data-00000-of-00001
可以看到,在模型引數恢復前需事先定義要恢復的變數,並且變數名需要與模型中儲存的變數名保持一致。
官方文件的說法是無需在引數恢復前對其進行初始化,但實際操作的時候有出現過報錯“FailedPreconditionError (see above for traceback): Attempting to use uninitialized value”的情況,此時利用tf.global_variables_initializer()
初始化變數可解決問題。
tf.train.import_meta_graph()
模型引數恢復之前需要先定義模型中儲存的變數,如果不想這樣做可以把模型的計算圖也恢復出來。tf.train.import_meta_graph()
函式就用於恢復模型,它的輸入引數為模型路徑,返回一個Saver類例項,再呼叫這個例項的restore()函式就可以恢復其引數了。示例如下:
import tensorflow as tf
sess = tf.Session()
new_saver = tf.train.import_meta_graph('/home/yy/xiajbxie/model.meta')
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state('/home/yy/xiajbxie')
if ckpt and ckpt.model_checkpoint_path:
print ckpt.model_checkpoint_path
new_saver.restore(sess, ckpt.model_checkpoint_path)
v1 = tf.get_default_graph().get_tensor_by_name('v1:0')
v2 = tf.get_default_graph().get_tensor_by_name('v2:0')
print(sess.run([v1, v2]))
執行結果:
/home/yy/xiajbxie/model
[array([[-0.0493206 , 0.12752049]], dtype=float32), array([[ 1.9456626 , 0.6319563 , -0.1296857 ],
[-0.7834143 , 0.33656874, -0.96077037]], dtype=float32)]
其中get_checkpoint_state()用於在傳入的路徑中尋找tensorflow檢查點。
tips
- 在不知道要過載的tensor叫什麼名字時可以在訓練階段列印變數名來觀察。
- 不能在與訓練資料相同的計算圖下載入以前儲存的計算圖,如果實在要這樣做也要保證兩個計算圖中不包含名字相同的變數。
- 利用tf.Graph()來生成新的計算圖,利用tf.Graph().as_default()來將新生成的計算圖設定為預設。