1. 程式人生 > >tensorflow從已經訓練好的模型中,恢復(指定)權重

tensorflow從已經訓練好的模型中,恢復(指定)權重

https://blog.csdn.net/AManFromEarth/article/details/79155926

https://blog.csdn.net/ying86615791/article/details/76215363

https://blog.csdn.net/ying86615791/article/details/76215363

遷移學習的實現需要網路在其他資料集上做預訓練,完成引數調優工作,然後拿預訓練好的引數在新的任務上做fine-tune,但是有時候可能只需要預訓練的網路的一部分權重,本文主要提供一個方法如何在tf上載入想要載入的權重。

在使用tensorflow載入網路權重的時候,直接使用tf.train.Saver().restore(sess, ‘ckpt’)的話是直接載入了全部權重,我們可能只需要載入網路的前幾層權重,或者只要或者不要特定幾層的權重,這時可以使用下面的方法:

var = tf.global_variables()
var_to_restore = [val  for val in var if 'conv1' in val.name or 'conv2'in val.name]
saver = tf.train.Saver(var_to_restore )
saver.restore(sess, os.path.join(model_dir, model_name))
var_to_init = [val  for val in var if 'conv1' not in val.name or 'conv2'not in val.name]
tf.initialize_variables(var_to_init)

這樣就只從ckpt檔案裡只讀取到了兩層卷積的卷積引數,前提是你的前兩層網路結構和名字和ckpt檔案裡定義的一樣。將var_to_restore和var_to_init反過來就是載入名字中不包含conv1、2的權重。

如果使用tensorflow的slim選擇性讀取權重的話就更方便了

exclude = ['layer1', 'layer2']
variables_to_restore = slim.get_variables_to_restore(exclude=exclude)
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess, os.path.join
(model_dir, model_name))

這樣就完成了不讀取ckpt檔案中’layer1’, ‘layer2’權重