16. TensorFlow: Saving and Restoring Model Parameters
阿新 • Published: 2019-02-07
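The simplest way to save and restore a model is to use a tf.train.Saver() object. It adds save and restore ops to the graph for all variables, or for the variables given in a list. The tf.train.Saver() object provides methods to run these ops and specifies the paths of the checkpoint files to read from and write to.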
I. The tf.train.Saver() class
tf.train.Saver(
    var_list=None,
    reshape=False,
    sharded=False,
    max_to_keep=5,
    keep_checkpoint_every_n_hours=10000.0,
    name=None,
    restore_sequentially=False,
    saver_def=None,
    builder=None,
    defer_build=False,
    allow_empty=False,
    write_version=tf.train.SaverDef.V2,
    pad_step_number=False,
    save_relative_paths=False,
    filename=None
)
1. Constructor arguments
- var_list
  - Specifies the variables that will be saved and restored. If None, defaults to the list of all saveable objects. It can be passed as a dict or a list:
    - A dict of names to variables: the keys are the names that will be used to save or restore the variables in the checkpoint files.
    - A list of variables: the variables will be keyed with their op name in the checkpoint files.
  - For example:
v1 = tf.Variable(..., name='v1')
v2 = tf.Variable(..., name='v2')

# 1. Pass them as a list; such a list can also be used to save or load only some of the variables
saver = tf.train.Saver([v1, v2])

# 2. Pass the variables as a dict:
saver = tf.train.Saver({'v1': v1, 'v2': v2})

# 3. Passing a list is equivalent to passing a dict with the variable op names as keys:
saver = tf.train.Saver({v.op.name: v for v in [v1, v2]})

# 4. Rename variables when saving or loading: the checkpoint keys are 'v1' and 'v2',
#    while the graph-level variable names remain 'other_v1' and 'other_v2'
v1 = tf.Variable(..., name='other_v1')
v2 = tf.Variable(..., name='other_v2')
saver = tf.train.Saver({'v1': v1, 'v2': v2})
print(v1.name)  # prints: other_v1:0
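The renaming in case 4 can be checked by reading the key names back from the checkpoint. The following is a minimal sketch, not part of the original post. It assumes the TF 1.x-style API is imported from tensorflow.compat.v1 (so that it also runs under TensorFlow 2.x), and it writes to a temporary directory. The initial values and the prefix rename_demo are made up for illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()  # graph mode, as in the rest of this post

graph = tf.Graph()
with graph.as_default():
    # The graph-level names are 'other_v1' and 'other_v2' ...
    v1 = tf.get_variable("other_v1", shape=[2], initializer=tf.ones_initializer())
    v2 = tf.get_variable("other_v2", shape=[3], initializer=tf.zeros_initializer())
    # ... but the checkpoint stores them under the dict keys 'v1' and 'v2'.
    saver = tf.train.Saver({"v1": v1, "v2": v2})
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    prefix = os.path.join(tempfile.mkdtemp(), "rename_demo")  # hypothetical path
    save_path = saver.save(sess, prefix)

# list_variables reads (name, shape) pairs from the checkpoint itself.
saved_names = sorted(name for name, _ in tf.train.list_variables(save_path))
print(saved_names)
print(v1.op.name, v2.op.name)
```

The names printed from the checkpoint are the dict keys, while v1.op.name and v2.op.name still report the graph-level names.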
- max_to_keep
  - Indicates the maximum number of recent checkpoint files to keep.
  - As new files are created, older files are deleted.
  - If None or 0, all checkpoint files are kept.
  - Defaults to 5 (that is, the 5 most recent checkpoint files are kept).
  - Setting max_to_keep=1 keeps only the latest model. Leaving global_step=None when calling save() has the same effect, because every save then writes to the same file name and overwrites the previous checkpoint.
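The effect of max_to_keep can be observed directly on disk. The sketch below is an illustration under assumptions, not code from the original post: it uses tensorflow.compat.v1, a temporary directory, a dummy scalar variable named counter, and max_to_keep=2 with five saves.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # A dummy variable, only there so that the Saver has something to save.
    counter = tf.get_variable("counter", shape=[], initializer=tf.zeros_initializer())
    saver = tf.train.Saver(max_to_keep=2)
    init_op = tf.global_variables_initializer()

ckpt_dir = tempfile.mkdtemp()
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    for step in range(1, 6):
        # Each call writes model-<step>.* and prunes checkpoints beyond max_to_keep.
        saver.save(sess, os.path.join(ckpt_dir, "model"), global_step=step)
    kept = [os.path.basename(p) for p in saver.last_checkpoints]

index_files = sorted(f for f in os.listdir(ckpt_dir) if f.endswith(".index"))
print(kept)
print(index_files)
```

Only the two most recent checkpoints (steps 4 and 5) should remain, both in saver.last_checkpoints and in the directory listing.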
2. Commonly used methods
# Returns a string, path at which the variables were saved.
save(
    sess,
    save_path,
    global_step=None,
    latest_filename=None,
    meta_graph_suffix='meta',
    write_meta_graph=True,
    write_state=True
)
# The variables to restore do not have to have been initialized, as restoring is itself a way to initialize variables.
restore(
    sess,
    save_path
)
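II. Saving and restoring parameters
1. Checkpoint files
- Variables are stored in binary files that mainly contain a mapping from variable names to tensor values.
- When you create a Saver object, you can optionally choose the names under which variables are stored in the checkpoint files. By default, the value of each variable's tf.Variable.name property is used. (That graph-level name is what identifies a model parameter; it has nothing to do with the name of the Python variable that holds it.)
- With saver = tf.train.Saver(max_to_keep=3), the files written for a checkpoint are:
  - The first file (checkpoint) holds the list of the paths of all model files in the directory.
  - The second file (.data-00000-of-00001) holds the model itself (all the values of the weights, biases, gradients and all the other variables saved).
  - The third file (.index) is the index.
  - The fourth file (.meta) holds the structure of the computation graph, including all variables, operations, collections, etc.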
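The four files described above can be listed right after a save. This is a minimal sketch under assumptions: tensorflow.compat.v1 as the import, a temporary directory as the checkpoint directory, and the prefix my_model with step 50 chosen only for illustration.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    w = tf.get_variable("weight", shape=[2], initializer=tf.zeros_initializer())
    saver = tf.train.Saver(max_to_keep=3)
    init_op = tf.global_variables_initializer()

checkpoint_dir = tempfile.mkdtemp()
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    saver.save(sess, os.path.join(checkpoint_dir, "my_model"), global_step=50)

# One state file plus three files per saved checkpoint.
files = sorted(os.listdir(checkpoint_dir))
for name in files:
    print(name)

# The 'checkpoint' state file records which checkpoint is the most recent one.
state = tf.train.get_checkpoint_state(checkpoint_dir)
latest = os.path.basename(state.model_checkpoint_path)
print(latest)
```

Note that save() and restore() take the prefix (here my_model-50), not the name of any single file on disk.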
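2. Saving and restoring variables
- A bool variable isTrain can be used to switch between the training and testing phases: True means training, False means testing.
- The tf.train.Saver() class supports renaming variables when they are restored (overriding the name argument of the original variable).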
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import tensorflow as tf

# Create some variables.
w = tf.get_variable("weight", shape=[2], initializer=tf.zeros_initializer())
b = tf.get_variable("bias", shape=[3], initializer=tf.zeros_initializer())
inc_w = w.assign(w + 1)
dec_b = b.assign(b - 1)

# Add an op to initialize the variables.
init_op = tf.global_variables_initializer()

# Add ops to save and restore all the variables.
saver = tf.train.Saver(max_to_keep=3)

isTrain = False  # True: training, False: testing
train_steps = 1000
checkpoint_steps = 50
checkpoint_dir = 'checkpoint/save&restore/'
model_name = 'my_model'

# Later, launch the model, initialize the variables, do some work, and save the
# variables to disk.
with tf.Session() as sess:
    sess.run(init_op)
    if isTrain:
        # Do some work with the model.
        for step in range(train_steps):
            inc_w.op.run()
            dec_b.op.run()
            if (step + 1) % checkpoint_steps == 0:
                # Append the step number to the checkpoint name:
                saved_path = saver.save(
                    sess,
                    checkpoint_dir + model_name,
                    global_step=step + 1  # if None, only the latest result is kept
                )
    else:
        print('Before restore:')
        print(sess.run(w))
        print(sess.run(b))
        # Look up the most recent model file
        ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
        if ckpt and ckpt.model_checkpoint_path:
            print("Success to load %s." % ckpt.model_checkpoint_path)
            saver.restore(sess, ckpt.model_checkpoint_path)
        else:
            pass
        print('After restore:')
        print(sess.run(w))
        print(sess.run(b))
# Test output
Before restore:
[ 0. 0.]
[ 0. 0. 0.]
Success to load checkpoint/save&restore/my_model-1000.
After restore:
[ 1000. 1000.]
[-1000. -1000. -1000.]
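# Conclusion: restore effectively re-initializes all the variables
# Analysis
Although the official documentation says that init_op does not need to be run before restore, sess.run(init_op) is deliberately placed above the if isTrain: statement here (so it runs in both the training and the testing phase) in order to verify that restore is effectively a re-initialization of all variables. The test output confirms this: the zeros written by init_op are overwritten by the restored values.
# sess.run(init_op) can therefore be moved inside the if isTrain: branch, so that it only runs in the training phase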
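The complementary experiment, restoring without ever running the initializer, can be sketched as follows. This is an illustration under assumptions rather than the author's code: it uses tensorflow.compat.v1, a temporary directory, and a hypothetical helper build_graph that creates the same variable with different initial values in two separate graphs.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()


def build_graph(init_value):
    """Hypothetical helper: a graph with one variable named 'weight' and a Saver."""
    graph = tf.Graph()
    with graph.as_default():
        w = tf.get_variable(
            "weight", shape=[2], initializer=tf.constant_initializer(init_value))
        init_op = tf.global_variables_initializer()
        saver = tf.train.Saver()
    return graph, w, init_op, saver


prefix = os.path.join(tempfile.mkdtemp(), "no_init_demo")  # hypothetical path

# "Training" run: initialize to 7.0 and save.
graph, w, init_op, saver = build_graph(7.0)
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    save_path = saver.save(sess, prefix)

# "Testing" run: a fresh graph whose initializer (0.0) is never executed.
graph, w, init_op, saver = build_graph(0.0)
with tf.Session(graph=graph) as sess:
    saver.restore(sess, save_path)  # init_op is deliberately not run
    restored = sess.run(w).tolist()

print(restored)
```

Reading w succeeds and yields the saved values, which is consistent with the conclusion that restore by itself initializes the variables it covers.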
3. Reading the values of trainable parameters and extracting a layer's features
sess = tf.Session()
# Returns all variables created with trainable=True in a var_list
var_list = tf.trainable_variables()
print("Trainable variables:------------------------")
# Print the index, shape and name of every trainable parameter
for idx, v in enumerate(var_list):
    print("param {:3}: {:15} {}".format(idx, str(v.get_shape()), v.name))
# Sample output for one network
Trainable variables:------------------------
param 0: (5, 5, 3, 32) conv2d/kernel:0
param 1: (32,) conv2d/bias:0
param 2: (5, 5, 32, 64) conv2d_1/kernel:0
param 3: (64,) conv2d_1/bias:0
param 4: (3, 3, 64, 128) conv2d_2/kernel:0
param 5: (128,) conv2d_2/bias:0
param 6: (3, 3, 128, 128) conv2d_3/kernel:0
param 7: (128,) conv2d_3/bias:0
param 8: (4608, 1024) dense/kernel:0
param 9: (1024,) dense/bias:0
param 10: (1024, 512) dense_1/kernel:0 ---> parameters of the dense2 layer
param 11: (512,) dense_1/bias:0
param 12: (512, 5) dense_2/kernel:0
param 13: (5,) dense_2/bias:0
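# Fetch the parameters W and b of the last fully connected layer
# (the variables must already have been initialized or restored in sess)
W = sess.run(var_list[12])
b = sess.run(var_list[13])
# Use the output of the second fully connected layer as the feature
# (dense2, x and img come from the network definition, which is not shown here)
feature = sess.run(dense2, feed_dict={x: img})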
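Indexing var_list by position (12 and 13 above) breaks as soon as a layer is added or removed. One possible alternative, sketched here under assumptions, is to look variables up by their op name. The example uses tensorflow.compat.v1 and builds only a stand-in for the last layer by hand, using the scope name dense_2 and the shapes from the sample output. It is not the original network.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    # Stand-in for the last fully connected layer of the sample network.
    with tf.variable_scope("dense_2"):
        kernel = tf.get_variable(
            "kernel", shape=[512, 5], initializer=tf.zeros_initializer())
        bias = tf.get_variable(
            "bias", shape=[5], initializer=tf.ones_initializer())
    init_op = tf.global_variables_initializer()
    # Map op names (without the ':0' suffix) to the variable objects.
    by_name = {v.op.name: v for v in tf.trainable_variables()}

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    W, b = sess.run([by_name["dense_2/kernel"], by_name["dense_2/bias"]])

print(W.shape, b.shape)
```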
III. Resuming training and fine-tuning a layer
1. Resuming training (all parameters)
import tensorflow as tf

# Define a global object for reading flag values; flags are referenced in the program as, e.g., FLAGS.continue_train
FLAGS = tf.app.flags.FLAGS
# Define command-line flags. Arguments: flag name, default value, description
tf.app.flags.DEFINE_string(
    "checkpoint_dir",
    "/path/to/checkpoint_save_dir/",
    "Directory name to save the checkpoints [checkpoint]"
)
tf.app.flags.DEFINE_boolean(
    "continue_train",
    False,
    "True for continue training.[False]"
)
# (the model must be built before this point so that the graph contains variables)
saver = tf.train.Saver()
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    if FLAGS.continue_train:
        # Automatically pick up the latest model file
        model_file = tf.train.latest_checkpoint(FLAGS.checkpoint_dir)
        saver.restore(sess, model_file)
        print("Success to load %s." % model_file)
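tf.train.latest_checkpoint() returns None when the directory holds no checkpoint yet, in which case restore() would fail. A guarded variant might look like the sketch below. The helper restore_or_init, the scalar variable step and the temporary directory are hypothetical names introduced for illustration, and the import assumes tensorflow.compat.v1.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()


def restore_or_init(sess, saver, init_op, checkpoint_dir):
    """Hypothetical helper: restore the newest checkpoint, else run init_op."""
    model_file = tf.train.latest_checkpoint(checkpoint_dir)
    if model_file is None:
        sess.run(init_op)  # nothing saved yet: start from scratch
    else:
        saver.restore(sess, model_file)  # resume from the saved values
    return model_file


graph = tf.Graph()
with graph.as_default():
    step_var = tf.get_variable("step", shape=[], initializer=tf.zeros_initializer())
    add_ten = step_var.assign_add(10.0)  # stands in for ten training steps
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()

checkpoint_dir = tempfile.mkdtemp()

# First run: the directory is empty, so the variables are initialized.
with tf.Session(graph=graph) as sess:
    first = restore_or_init(sess, saver, init_op, checkpoint_dir)
    sess.run(add_ten)
    saver.save(sess, os.path.join(checkpoint_dir, "model"), global_step=10)

# Second run: training continues from the saved value.
with tf.Session(graph=graph) as sess:
    second = restore_or_init(sess, saver, init_op, checkpoint_dir)
    resumed_value = float(sess.run(step_var))

print(first, os.path.basename(second), resumed_value)
```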
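2. Fine-tuning a layer
- In the network's weight and bias definitions, set trainable=False on the variables that should stay fixed and not be trained.
- Then resume training with the code above.
e.g.: my_non_trainable = tf.get_variable("my_non_trainable", shape=(3, 3), trainable=False)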
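The effect of trainable=False can be seen in a toy graph: the frozen variable is left out of tf.trainable_variables(), so an optimizer's minimize() does not update it by default, but it still appears in tf.global_variables() and is therefore still saved and restored. The sketch below assumes tensorflow.compat.v1. The variable names, the loss and the learning rate are made up for illustration.

```python
import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()

graph = tf.Graph()
with graph.as_default():
    frozen = tf.get_variable(
        "frozen_w", shape=[3, 3], initializer=tf.ones_initializer(), trainable=False)
    tuned = tf.get_variable(
        "tuned_w", shape=[3, 3], initializer=tf.ones_initializer())
    trainable_names = [v.op.name for v in tf.trainable_variables()]
    global_names = sorted(v.op.name for v in tf.global_variables())

    # Toy loss; minimize() only updates the trainable variables by default.
    loss = tf.reduce_sum(frozen * tuned)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
    init_op = tf.global_variables_initializer()

with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    sess.run(train_op)  # one gradient step
    frozen_after, tuned_after = sess.run([frozen, tuned])

print(trainable_names, global_names)
print(frozen_after[0][0], tuned_after[0][0])
```

After one step the frozen variable keeps its initial value, while the tuned one has moved by the learning rate times the gradient.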
- Restore a meta checkpoint (still to be written up)
  - Use the TF helper tf.train.import_meta_graph()
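As a starting point for that write-up, here is a minimal sketch of restoring from the .meta file without re-running the model-building code. It is an illustration under assumptions, not the author's method: it uses tensorflow.compat.v1, a temporary directory, and a single variable named weight.

```python
import os
import tempfile

import tensorflow.compat.v1 as tf  # assumption: TF 1.x-style API via the compat module

tf.disable_eager_execution()

prefix = os.path.join(tempfile.mkdtemp(), "meta_demo")  # hypothetical path

# Build and save a small graph; save() also writes <prefix>.meta by default.
graph = tf.Graph()
with graph.as_default():
    tf.get_variable("weight", shape=[2], initializer=tf.constant_initializer(5.0))
    init_op = tf.global_variables_initializer()
    saver = tf.train.Saver()
with tf.Session(graph=graph) as sess:
    sess.run(init_op)
    save_path = saver.save(sess, prefix)

# Later: rebuild the graph structure from the .meta file alone.
restored_graph = tf.Graph()
with restored_graph.as_default():
    restored_saver = tf.train.import_meta_graph(save_path + ".meta")
    # The imported collections give access to the variables by op name.
    variables = {v.op.name: v for v in tf.global_variables()}

with tf.Session(graph=restored_graph) as sess:
    restored_saver.restore(sess, save_path)  # load the saved values
    weight = sess.run(variables["weight"]).tolist()

print(weight)
```

import_meta_graph() returns a Saver built from the saved graph definition, so the same restore() call as before loads the variable values.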