1. 程式人生 > 其它 >tensorflow載入與儲存模型

tensorflow載入與儲存模型

技術標籤:DeepLearning

問題:

1.訓練好分類模型,比如訓練儲存了一個10分類的模型,但是實際用的時候呢,分類數可能會改變,但是還想繼續使用前面儲存的模型。那麼相當於是隻載入前幾層的引數,最後一層做一些修改。
2.載入預訓練模型時,預訓練模型缺少網路中定義的變數

儲存模型

saver = tf.train.Saver()
saver.save(sess,“model.ckpt”)

載入模型

saver.restore(sess,“model.ckpt”)

不傳引數時,相當於是儲存了所有的引數,然後載入所有的引數。

載入模型時變數缺失情況

我們可以先將模型定義的變數輸出看一下,得到變數資訊方式有很多

以resnet50為例

    from tensorflow.contrib.slim.nets import resnet_v1
	slim = tensorflow.contrib.slim
	
    inputdata = tf.placeholder(tf.float32, shape=(1, 224, 224, 3), name='input')
    net, end_points = resnet_v1.resnet_v1_50(inputdata, 1000, is_training=False)

我們在構建好網路後,載入網路中定義的變數

# way1
variables_to_resotre =
tf.global_variables() # way2 variables_to_resotre = slim.get_variables_to_restore()

輸出這些變數的型別,可發現這些變數為列表型別。
那麼我們只需要刪除掉我們不需要的元素即可。
使用元素的name屬性可得到其名稱
在這裡插入圖片描述

刪除樣例如下:

#得到該網路中,所有可以載入的引數
variables = tf.slim.get_variables_to_restore()
#刪除output層中的引數
variables_to_resotre = [v for v in varialbes if v.name !='output'
] #構建這部分引數的saver saver = tf.train.Saver(variables_to_restore) saver.restore(sess,'model.ckpt')