1. 程式人生 > >TensorFlow學習筆記(8)--網路模型的儲存和讀取

TensorFlow學習筆記(8)--網路模型的儲存和讀取

之前的筆記裡實現了softmax迴歸分類、簡單的含有一個隱層的神經網路、卷積神經網路等等,但是這些程式碼在訓練完成之後就直接退出了,並沒有將訓練得到的模型儲存下來方便下次直接使用。為了讓訓練結果可以複用,需要將訓練好的神經網路模型持久化,這就是這篇筆記裡要寫的東西。

TensorFlow提供了一個非常簡單的API,即tf.train.Saver類來儲存和還原一個神經網路模型。

下面程式碼給出了儲存TensorFlow模型的方法:

import tensorflow as tf

# 宣告兩個變數
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1"
) v2 = tf.Variable(tf.random_normal([2, 3]), name="v2") init_op = tf.global_variables_initializer() # 初始化全部變數 saver = tf.train.Saver(write_version=tf.train.SaverDef.V1) # 宣告tf.train.Saver類用於儲存模型 with tf.Session() as sess: sess.run(init_op) print("v1:", sess.run(v1)) # 列印v1、v2的值一會讀取之後對比 print("v2:"
, sess.run(v2)) saver_path = saver.save(sess, "save/model.ckpt") # 將模型儲存到save/model.ckpt檔案 print("Model saved in file:", saver_path)

注:Saver方法已經發生了更改,現在是V2版本,tf.train.Saver(write_version=tf.train.SaverDef.V1)括號里加入該引數可繼續使用V1,但會報warning,可忽略。若使用saver = tf.train.Saver()則預設使用當前的版本(V2),儲存後在save這個資料夾中會出現4個檔案,比V1版多出model.ckpt.data-00000-of-00001

這個檔案,這點感謝評論裡那位朋友指出。至於這個檔案的含義到目前我仍不是很清楚,也沒查到具體資料,TensorFlow15年底開源到現在很多類啊函式都一直髮生著變動,或被更新或被棄用,可能一些程式碼在當時是沒問題的,但過了一大段時間後再跑可能就會報錯,在此註明事件時間:2017.4.30

這段程式碼中,通過saver.save函式將TensorFlow模型儲存到了save/model.ckpt檔案中,這裡程式碼中指定路徑為"save/model.ckpt",也就是儲存到了當前程式所在資料夾裡面的save資料夾中。

TensorFlow模型會儲存在後綴為.ckpt的檔案中。儲存後在save這個資料夾中會出現3個檔案,因為TensorFlow會將計算圖的結構和圖上引數取值分開儲存。

  • checkpoint檔案儲存了一個目錄下所有的模型檔案列表,這個檔案是tf.train.Saver類自動生成且自動維護的。在 checkpoint檔案中維護了由一個tf.train.Saver類持久化的所有TensorFlow模型檔案的檔名。當某個儲存的TensorFlow模型檔案被刪除時,這個模型所對應的檔名也會從checkpoint檔案中刪除。checkpoint中內容的格式為CheckpointState Protocol Buffer.

  • model.ckpt.meta檔案儲存了TensorFlow計算圖的結構,可以理解為神經網路的網路結構
    TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的資訊以及執行計算圖中節點所需要的元資料。TensorFlow中元圖是由MetaGraphDef Protocol Buffer定義的。MetaGraphDef 中的內容構成了TensorFlow持久化時的第一個檔案。儲存MetaGraphDef 資訊的檔案預設以.meta為字尾名,檔案model.ckpt.meta中儲存的就是元圖資料。

  • model.ckpt檔案儲存了TensorFlow程式中每一個變數的取值,這個檔案是通過SSTable格式儲存的,可以大致理解為就是一個(key,value)列表。model.ckpt檔案中列表的第一行描述了檔案的元資訊,比如在這個檔案中儲存的變數列表。列表剩下的每一行儲存了一個變數的片段,變數片段的資訊是通過SavedSlice Protocol Buffer定義的。SavedSlice型別中儲存了變數的名稱、當前片段的資訊以及變數取值。TensorFlow提供了tf.train.NewCheckpointReader類來檢視model.ckpt檔案中儲存的變數資訊。如何使用tf.train.NewCheckpointReader類這裡不做說明,自查。

這裡寫圖片描述

下面程式碼給出了載入TensorFlow模型的方法:

可以對比一下v1、v2的值是隨機初始化的值還是和之前儲存的值是一樣的?

import tensorflow as tf

# 使用和儲存模型程式碼中一樣的方式來宣告變數
v1 = tf.Variable(tf.random_normal([1, 2]), name="v1")
v2 = tf.Variable(tf.random_normal([2, 3]), name="v2")
saver = tf.train.Saver() # 宣告tf.train.Saver類用於儲存模型
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt") # 即將固化到硬碟中的Session從儲存路徑再讀取出來
    print("v1:", sess.run(v1)) # 列印v1、v2的值和之前的進行對比
    print("v2:", sess.run(v2))
    print("Model Restored")

執行結果:

v1: [[ 0.76705766  1.82217288]]
v2: [[-0.98012197  1.2369734   0.5797025 ]
 [ 2.50458145  0.81897354  0.07858191]]
Model Restored

這段載入模型的程式碼基本上和儲存模型的程式碼是一樣的。也是先定義了TensorFlow計算圖上所有的運算,並聲明瞭一個tf.train.Saver類。兩段唯一的不同是,在載入模型的程式碼中沒有執行變數的初始化過程,而是將變數的值通過已經儲存的模型載入進來。
也就是說使用TensorFlow完成了一次模型的儲存和讀取的操作。

如果不希望重複定義圖上的運算,也可以直接載入已經持久化的圖:

import tensorflow as tf
# 在下面的程式碼中,預設載入了TensorFlow計算圖上定義的全部變數
# 直接載入持久化的圖
saver = tf.train.import_meta_graph("save/model.ckpt.meta")
with tf.Session() as sess:
    saver.restore(sess, "save/model.ckpt")
    # 通過張量的名稱來獲取張量
    print(sess.run(tf.get_default_graph().get_tensor_by_name("v1:0")))

執行程式,輸出:

[[ 0.76705766  1.82217288]]

有時可能只需要儲存或者載入部分變數。
比如,可能有一個之前訓練好的5層神經網路模型,但現在想寫一個6層的神經網路,那麼可以將之前5層神經網路中的引數直接載入到新的模型,而僅僅將最後一層神經網路重新訓練。

為了儲存或者載入部分變數,在宣告tf.train.Saver類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用saver = tf.train.Saver([v1])命令來構建tf.train.Saver類,那麼只有變數v1會被載入進來。

…未完待續