TensorFlow學習筆記(8)--網路模型的儲存和讀取
之前的筆記裡實現了softmax迴歸分類、簡單的含有一個隱層的神經網路、卷積神經網路等等,但是這些程式碼在訓練完成之後就直接退出了,並沒有將訓練得到的模型儲存下來方便下次直接使用。為了讓訓練結果可以複用,需要將訓練好的神經網路模型持久化,這就是這篇筆記裡要寫的東西。
TensorFlow提供了一個非常簡單的API,即tf.train.Saver
類來儲存和還原一個神經網路模型。
下面程式碼給出了儲存TensorFlow模型的方法:
import tensorflow as tf
# 宣告兩個變數
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1" )
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
init_op = tf.global_variables_initializer() # 初始化全部變數
saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 宣告tf.train.Saver類用於儲存模型
with tf.Session() as sess:
sess.run(init_op)
print("v1:", sess.run(v1)) # 列印v1、v2的值一會讀取之後對比
print("v2:" , sess.run(v2))
saver_path = saver.save(sess, "save/model.ckpt") # 將模型儲存到save/model.ckpt檔案
print("Model saved in file:", saver_path)
注:Saver方法已經發生了更改,現在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括號里加入該引數可繼續使用V1,但會報warning,可忽略。若使用saver = tf.train.Saver()則預設使用當前的版本(V2),儲存後在save這個資料夾中會出現4個檔案,比V1版多出model.ckpt.data-00000-of-00001
這段程式碼中,通過saver.save
函式將TensorFlow模型儲存到了save/model.ckpt檔案中,這裡程式碼中指定路徑為"save/model.ckpt"
,也就是儲存到了當前程式所在資料夾裡面的save
資料夾中。
TensorFlow模型會儲存在後綴為.ckpt
的檔案中。儲存後在save這個資料夾中會出現3個檔案,因為TensorFlow會將計算圖的結構和圖上引數取值分開儲存。
checkpoint
檔案儲存了一個目錄下所有的模型檔案列表,這個檔案是tf.train.Saver
類自動生成且自動維護的。在checkpoint
檔案中維護了由一個tf.train.Saver
類持久化的所有TensorFlow模型檔案的檔名。當某個儲存的TensorFlow模型檔案被刪除時,這個模型所對應的檔名也會從checkpoint
檔案中刪除。checkpoint
中內容的格式為CheckpointState Protocol Buffer.model.ckpt.meta
檔案儲存了TensorFlow計算圖的結構,可以理解為神經網路的網路結構
TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的資訊以及執行計算圖中節點所需要的元資料。TensorFlow中元圖是由MetaGraphDef Protocol Buffer定義的。MetaGraphDef 中的內容構成了TensorFlow持久化時的第一個檔案。儲存MetaGraphDef 資訊的檔案預設以.meta為字尾名,檔案model.ckpt.meta
中儲存的就是元圖資料。model.ckpt
檔案儲存了TensorFlow程式中每一個變數的取值,這個檔案是通過SSTable格式儲存的,可以大致理解為就是一個(key,value)列表。model.ckpt
檔案中列表的第一行描述了檔案的元資訊,比如在這個檔案中儲存的變數列表。列表剩下的每一行儲存了一個變數的片段,變數片段的資訊是通過SavedSlice Protocol Buffer定義的。SavedSlice型別中儲存了變數的名稱、當前片段的資訊以及變數取值。TensorFlow提供了tf.train.NewCheckpointReader
類來檢視model.ckpt
檔案中儲存的變數資訊。如何使用tf.train.NewCheckpointReader
類這裡不做說明,自查。
下面程式碼給出了載入TensorFlow模型的方法:
可以對比一下v1、v2的值是隨機初始化的值還是和之前儲存的值是一樣的?
import tensorflow as tf
# 使用和儲存模型程式碼中一樣的方式來宣告變數
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 宣告tf.train.Saver類用於儲存模型
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt") # 即將固化到硬碟中的Session從儲存路徑再讀取出來
print("v1:", sess.run(v1)) # 列印v1、v2的值和之前的進行對比
print("v2:", sess.run(v2))
print("Model Restored")
執行結果:
v1: [[ 0.76705766 1.82217288]]
v2: [[-0.98012197 1.2369734 0.5797025 ]
[ 2.50458145 0.81897354 0.07858191]]
Model Restored
這段載入模型的程式碼基本上和儲存模型的程式碼是一樣的。也是先定義了TensorFlow計算圖上所有的運算,並聲明瞭一個tf.train.Saver
類。兩段唯一的不同是,在載入模型的程式碼中沒有執行變數的初始化過程,而是將變數的值通過已經儲存的模型載入進來。
也就是說使用TensorFlow完成了一次模型的儲存和讀取的操作。
如果不希望重複定義圖上的運算,也可以直接載入已經持久化的圖:
import tensorflow as tf
# 在下面的程式碼中,預設載入了TensorFlow計算圖上定義的全部變數
# 直接載入持久化的圖
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
saver.restore(sess, "save/model.ckpt")
# 通過張量的名稱來獲取張量
print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))
執行程式,輸出:
[[ 0.76705766 1.82217288]]
有時可能只需要儲存或者載入部分變數。
比如,可能有一個之前訓練好的5層神經網路模型,但現在想寫一個6層的神經網路,那麼可以將之前5層神經網路中的引數直接載入到新的模型,而僅僅將最後一層神經網路重新訓練。
為了儲存或者載入部分變數,在宣告tf.train.Saver
類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用saver = tf.train.Saver([v1])
命令來構建tf.train.Saver
類,那麼只有變數v1會被載入進來。
…未完待續