tensorflow 恢復部分引數、載入指定引數
阿新 • • 發佈:2019-01-10
多分類採用與訓練模型輸出不匹配,我們需要載入部分預訓練模型的引數。
我們先看一下如何儲存和讀入預訓練模型。
#一般實驗情況下儲存的時候,都是用的saver類來儲存,如下 saver = tf.train.Saver() saver.save(sess,"model.ckpt") #載入時的程式碼 saver.restore(sess,"model.ckpt") #前面的描述相當於是儲存了所有的引數,然後載入所有的引數。 #但是目前的情況有所變化了,不能載入所有的引數,最後一層的引數不一樣了,需要隨機初始化。 #首先對每一層新增name scope,如下: with name_scope('conv1'): xxx with name_scope('conv2'): xxx with name_scope('fc1'): xxx with name_scope('output'): xxx #然後根據變數的名字,選擇載入哪些變數, #得到該網路中,所有可以載入的引數 variables = tf.contrib.framework.get_variables_to_restore() #刪除output層中的引數 variables_to_resotre = [v for v in varialbes if v.name.split('/')[0]!='output'] #構建這部分引數的 saversaver = tf.train.Saver(variables_to_restore) saver.restore(sess,'model.ckpt') #在tensorflow中,有多種方式可以得到變數的資訊: tf.contrib.framework.get_variables_to_restore() tf.all_variables()tf.trainable_varialbes()
多分類採用與訓練模型輸出不匹配解決方法:
利用tf.contrib.framework.get_variables_to_restore()函式,程式碼如下
variables_to_restore = tf.contrib.framework.get_variables_to_restore(exclude=['resnet50/fc']) saver = tf.train.Saver(variables_to_restore) with tf.Session() as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess, param_path)
exclude=['resnet50/fc']表示載入預訓練引數中除了resnet50/fc這一層之外的其他所有引數。
include=["inceptionv3"]表示只加載inceptionv3這一層的所有引數。
param_path是你預訓練引數儲存地址。
注:如果不止一個層引數需要丟棄,exclue=['a', 'b']即可。調優訓練(fine_tuning)時最好把前面曾trainable設為False,只訓練最後一層。