TensorFlow模型儲存和載入方法彙總
技術標籤:tensorflow
目錄
一、TensorFlow常規模型載入方法
儲存模型
tf.train.Saver()類,.save(sess, ckpt檔案目錄)方法
引數名稱 | 功能說明 | 預設值 |
var_list | Saver中儲存變數集合 | 全域性變數集合 |
reshape | 載入時是否恢復變數形狀 | True |
sharded | 是否將變數輪循放在所有裝置上 | True |
max_to_keep | 保留最近檢查點個數 | 5 |
restore_sequentially | 是否按順序恢復變數,模型較大時順序恢復記憶體消耗小 | True |
var_list是字典形式{變數名字串: 變數符號},相對應的restore也根據同樣形式的字典將ckpt中的字串對應的變數載入給程式中的符號。
如果Saver給定了字典作為載入方式,則按照字典來,如:saver
=
tf.train.Saver({
"v/ExponentialMovingAverage"
:v}),否則每個變數尋找自己的name屬性在ckpt中的對應值進行載入。
載入模型
當我們基於checkpoint檔案(ckpt)載入引數時,實際上我們使用Saver.restore取代了initializer的初始化
checkpoint檔案會記錄儲存資訊,通過它可以定位最新儲存的模型:
1 2 |
|
.meta檔案儲存了當前圖結構
.data檔案儲存了當前引數名和值
.index檔案儲存了輔助索引資訊
.data檔案可以查詢到引數名和引數值,使用下面的命令可以查詢儲存在檔案中的全部變數{名:值}對,
1 2 |
|
tf.train.import_meta_graph函式給出model.ckpt-n.meta的路徑後會載入圖結構,並返回saver物件
1 |
|
tf.train.Saver函式會返回載入預設圖的saver物件,saver物件初始化時可以指定變數對映方式,根據名字對映變數(『TensorFlow』滑動平均)
1 |
|
saver.restore函式給出model.ckpt-n的路徑後會自動尋找引數名-值檔案進行載入
1 2 |
|
1.不載入圖結構,只加載引數
由於實際上我們引數儲存的都是Variable變數的值,所以其他的引數值(例如batch_size)等,我們在restore時可能希望修改,但是圖結構在train時一般就已經確定了,所以我們可以使用tf.Graph().as_default()新建一個預設圖(建議使用上下文環境),利用這個新圖修改和變數無關的參值大小,從而達到目的。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
|
2.載入圖結構和引數
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 |
|
注意,在所有兩種方式中都可以通過呼叫節點名稱使用節點輸出張量,節點.name屬性返回節點名稱。
3.簡化版本
1 2 3 4 5 6 7 8 9 10 11 12 |
|
二、TensorFlow二進位制模型載入方法
這種載入方法一般是對應網上各大公司已經訓練好的網路模型進行修改的工作
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 |
|
上面兩篇都使用了二進位制載入模型的方式
三、二進位制模型製作
這節是關於tensorflow的Freezing,字面意思是冷凍,可理解為整合合併;整合什麼呢,就是將模型檔案和權重檔案整合合併為一個檔案,主要用途是便於釋出。
tensorflow在訓練過程中,通常不會將權重資料儲存的格式檔案裡(這裡我理解是模型檔案),反而是分開儲存在一個叫checkpoint的檢查點檔案裡,當初始化時,再通過模型檔案裡的變數Op節點來從checkoupoint檔案讀取資料並初始化變數。這種模型和權重資料分開儲存的情況,使得釋出產品時不是那麼方便,我們可以將tf的圖和引數檔案整合進一個字尾為pb的二進位制檔案中,由於整合過程回將變數轉化為常量,所以我們在日後讀取模型檔案時不能夠進行訓練,僅能向前傳播,而且我們在儲存時需要指定節點名稱。
將圖變數轉換為常量的API:tf.graph_util.convert_variables_to_constants
轉換後的graph_def物件轉換為二進位制資料(graph_def.SerializeToString())後,寫入pb即可。
1 2 3 4 5 6 7 8 9 10 11 12 13 |
|
我們可以直接檢視gd:
node {
name: "v1"
op: "Const"
attr {
key: "dtype"
value {
type: DT_FLOAT
}
}
attr {
key: "value"
value {
tensor {
dtype: DT_FLOAT
tensor_shape {
dim {
size: 1
}
}
float_val: 1.0
}
}
}
}
……
node {
name: "add"
op: "Add"
input: "v1/read"
input: "v2/read"
attr {
key: "T"
value {
type: DT_FLOAT
}
}
}
library {
}
四、從圖上讀取張量
上面的程式碼實際上已經包含了本小節的內容,但是由於從圖上讀取特定的張量是如此的重要,所以我仍然單獨的補充上這部分的內容。
無論如何,想要獲取特定的張量我們必須要有張量的名稱和圖的控制代碼,比如 'import/pool_3/_reshape:0' 這種,有了張量名和圖,索引就很簡單了。
從二進位制模型載入張量
第二小節的程式碼很好的展示了這種情況
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 |
|
從當前圖中獲取對應張量
這個就是很普通的情況,從我們當前操作的圖中獲取某個張量,用於feed啦或者用於輸出等操作,API也很簡單,用法如下:
g.get_tensor_by_name('import/pool_3/_reshape:0')
g表示當前圖控制代碼,可以簡單的使用 g = tf.get_default_graph() 獲取。
從圖中獲取節點資訊
有的時候我們對於模型中的節點並不夠了解,此時我們可以通過圖控制代碼來查詢圖的構造:
1 2 |
|
這個操作將返回圖的構造結構。從這裡,對比前面的程式碼,我們也可以瞭解到:graph_def 實際就是圖的結構資訊儲存形式,我們可以將之還原為圖(二進位制模型載入程式碼中展示了),也可以從圖中將之提取出來(本部分程式碼)。