tensorflow載入與儲存模型
阿新 • • 發佈:2021-01-12
技術標籤:DeepLearning
問題:
1.訓練好分類模型,比如訓練儲存了一個10分類的模型,但是實際用的時候呢,分類數可能會改變,但是還想繼續使用前面儲存的模型。那麼相當於是隻載入前幾層的引數,最後一層做一些修改。
2.載入預訓練模型時,預訓練模型缺少網路中定義的變數
儲存模型
saver = tf.train.Saver()
saver.save(sess,“model.ckpt”)
載入模型
saver.restore(sess,“model.ckpt”)
不傳引數時,相當於是儲存了所有的引數,然後載入所有的引數。
載入模型時變數缺失情況
我們可以先將模型定義的變數輸出看一下,得到變數資訊方式有很多
from tensorflow.contrib.slim.nets import resnet_v1
slim = tensorflow.contrib.slim
inputdata = tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name='input')
net, end_points = resnet_v1.resnet_v1_50(inputdata, 1000, is_training=False)
我們在構建好網路後,載入網路中定義的變數
# way1
variables_to_resotre = tf.global_variables()
# way2
variables_to_resotre = slim.get_variables_to_restore()
輸出這些變數的型別,可發現這些變數為列表型別。
那麼我們只需要刪除掉我們不需要的元素即可。
使用元素的name屬性可得到其名稱
刪除樣例如下:
#得到該網路中,所有可以載入的引數
variables = tf.slim.get_variables_to_restore()
#刪除output層中的引數
variables_to_resotre = [v for v in varialbes if v.name !='output' ]
#構建這部分引數的saver
saver = tf.train.Saver(variables_to_restore)
saver.restore(sess,'model.ckpt')