tensorflow 模型預訓練後的引數restore finetuning
之前訓練的網路中有一部分可以用到一個新的網路中,但是不知道儲存的引數如何部分恢復到新的網路中,也瞭解到有許多網路是通過利用一些現有的網路結構,通過finetuning進行改造實現的,因此瞭解了一下關於模型預訓練後部分引數restore和finetuning的內容
更多內容參見:
https://blog.csdn.net/mieleizhi0522/article/details/80535189
https://blog.csdn.net/leo_xu06/article/details/79200634
https://blog.csdn.net/b876144622/article/details/79962727
首先了解一下變數(tf.Variable),變數是tf框架中用於儲存引數的物件,我們這裡要恢復的引數也是variable型別的。訓練的引數是放在不同名字下的variable中的,checkpoint中儲存的變數也是通過不同的名字進行區分的,這裡如果要恢復指定的引數可以使用
with tf.variable_scope('', reuse = True): sess.run(tf.get_variable(your_var_name).assign(reader.get_tensor(pretrained_var_name)))
Saver是用於儲存變數的物件。下面是saver物件的建立和呼叫
saver = tf.train.Saver()
save_path = saver.save(sess, "/tmp/model.ckpt")
如果僅在session開始時恢復模型變數的一個子集,需要對剩下的變數執行初始化op。
# Create some variables. v1 = tf.Variable(..., name="v1") v2 = tf.Variable(..., name="v2") ... # Add ops to save and restore only 'v2' using the name "my_v2" saver = tf.train.Saver({"my_v2": v2})
對已有checkpoint內容進行檢視,可以使用一下程式碼(來自https://blog.csdn.net/mieleizhi0522/article/details/80535189),然後就可以結合之前的指定變數名的方法對引數進行restore了。注意,在完成部分引數的restore後要記得對沒有初始化的變數進行初始化,否則報錯。
import tensorflow as tf
import os
from tensorflow.python import pywrap_tensorflow
model_dir=r'G:\KeTi\C3D'
checkpoint_path = os.path.join(model_dir, "sports1m_finetuning_ucf101.model")
# 從checkpoint中讀出資料
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
# reader = tf.train.NewCheckpointReader(checkpoint_path) # 用tf.train中的NewCheckpointReader方法
var_to_shape_map = reader.get_variable_to_shape_map()
# 輸出權重tensor名字和值
for key in var_to_shape_map:
print("tensor_name: ", key,reader.get_tensor(key).shape)
輸出
tensor_name: var_name/wc4a (3, 3, 3, 256, 512)
tensor_name: var_name/wc3a (3, 3, 3, 128, 256)
tensor_name: var_name/wd1 (8192, 4096)
tensor_name: var_name/wc5b (3, 3, 3, 512, 512)
tensor_name: var_name/bd1 (4096,)
tensor_name: var_name/wd2 (4096, 4096)
tensor_name: var_name/wout (4096, 101)
tensor_name: var_name/wc1 (3, 3, 3, 3, 64)
tensor_name: var_name/bc4b (512,)
tensor_name: var_name/wc2 (3, 3, 3, 64, 128)
tensor_name: var_name/bc3a (256,)
tensor_name: var_name/bd2 (4096,)
tensor_name: var_name/bc5a (512,)
tensor_name: var_name/bc2 (128,)
tensor_name: var_name/bc5b (512,)
tensor_name: var_name/bout (101,)
tensor_name: var_name/bc4a (512,)
tensor_name: var_name/bc3b (256,)
tensor_name: var_name/wc4b (3, 3, 3, 512, 512)
tensor_name: var_name/bc1 (64,)
tensor_name: var_name/wc3b (3, 3, 3, 256, 256)
tensor_name: var_name/wc5a (3, 3, 3, 512, 512)