1. 程式人生 > >TensorFlow 入門 3 ——變數管理和模型持久化

TensorFlow 入門 3 ——變數管理和模型持久化

變數管理

TensorFlow 提供了通過變數名稱來建立或者獲取一個變數的機制。通過這個機制,在不同的函式中可以直接通過變數名稱來使用變數,而不需要將變數通過引數的形式到處傳遞。TensorFlow中通過變數名稱獲取變數的機制主要是通過tf.get_variable和tf.variable_scope函式實現的。

TensorFLow還提供了tf.get_variable函式來建立或者獲取變數,tf.get_variable用於建立變數時,其功能和tf.Variable基本是等價的。tf.get_variable中的初始化方法(initializer)的引數和tf.Variable的初始化過程也類似,initializer函式和tf.Variable的初始化方法是一一對應的。

#以下兩個定義是等價的

##首先根據"v"這個名稱來建立一個引數,如果建立失敗(比如已經有同名的引數),那麼這個程式就會報錯。(防止變數重複建立)
v = tf.get_variable("v", shape=[1], inittializer=tf.constant_initializer(1.0))
v = tf.Variable(tf.constant(1.0, shape=[1]), name="v"))
##**The best way to create a variable is to call the tf.get_variable function.**

TensorFlow中提供的initializer函式和隨機數以及常數生成函式大部分是意義對應的。
這裡寫圖片描述


tf.get_variable和tf.Variable最大的區別就在於指定變數名稱的引數。對於tf.Variable函式,變數名稱是一個可選的引數。對於tf.get_variable函式,變數名稱是一個必填的引數,tf.get_variable會根據這個名稱去建立或者獲取變數。

tf.variable_scope

如果需要通過tf.get_variable獲取一個已經建立的變數,需要通過tf.variable_scope函式來生成一個上下文管理器,並明確指定在這個上下文管理器中,tf.get_variable將直接獲取已建立的變數。下面一段程式碼說明了如何通過tf.variable_scope函式來控制tf.get_variable函式獲取建立過的變數。
這裡寫圖片描述


通過tf.variable_scope函式可以控制tf.get_variable函式的語義。當tf.variable_scope函式的引數reuse=True生成上下文管理器時,該上下文管理器內的所有的tf.get_variable函式會直接獲取已經建立的變數,如果變數不存在則報錯;當tf.variable_scope函式的引數reuse=False或者None時建立的上下文管理器中,tf.get_variable函式則直接建立新的變數,若同名的變數已經存在則報錯。

另tf.variable_scope函式是可以巢狀使用的。巢狀的時候,若某層上下文管理器未宣告reuse引數,則該層上下文管理器的reuse引數與其外層保持一致。
這裡寫圖片描述
tf.variable_scope函式提供了一個管理變數名稱空間的方式。在tf.variable_scope中建立的變數,名稱.name中名稱前面會加入名稱空間的名稱,並通過“/”來分隔名稱空間的名稱和變數的名稱。tf.get_variable(“foou/baru/u”, [1]),可以通過帶名稱空間名稱的變數名來獲取其名稱空間下的變數。
這裡寫圖片描述

模型持久化

當我們使用 tensorflow 訓練神經網路的時候,模型持久化對於我們的訓練有很重要的作用。

  1. 如果我們的神經網路比較複雜,訓練資料比較多,那麼我們的模型訓練就會耗時很長,如果在訓練過程中出現某些不可預計的錯誤,導致我們的訓練意外終止,那麼我們將會前功盡棄。為了避免這個問題,我們就可以通過模型持久化(儲存為CKPT格式)來暫存我們訓練過程中的臨時資料。

  2. 如果我們訓練的模型需要提供給使用者做離線的預測,那麼我們只需要前向傳播的過程,只需得到預測值就可以了,這個時候我們就可以通過模型持久化(儲存為PB格式)只儲存前向傳播中需要的變數並將變數的值固定下來,這個時候只需使用者提供一個輸入,我們就可以通過模型得到一個輸出給使用者。

持久化程式碼實現

TensorFlow提供了一個非常簡單的API來儲存和還原一個神經網路模型。這個API就是tf.train.Saver類。以下程式碼給出了儲存TensorFlow計算圖的方法。

儲存
import tensorflow as tf
#宣告兩個變數並計算他們的和
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")
result = v1 + v2

init_op = tf.global_initialize_variables()
#宣告tf.train.Saver類用於儲存模型
saver = tf.train.Saver()

with tf.Session() as sess:
    sess.run(init_op)
    #將模型儲存到/path/to/model/model.ckpt檔案。
    saver.save(sess, "/path/to/model/model.ckpt")

上面的程式碼實現了持久化一個簡單的TensorFlow模型的功能。在這個段程式碼中,通過saver.save函式將TensorFlow模型儲存到/path/to/model/model.ckpt檔案中。TensorFlow模型一般會儲存在後綴為.ckpt的檔案中。同時在這個檔案目錄下會出現三個檔案。這是因為TensorFlow會將計算圖的結構和圖上引數取值分開儲存。

  1. checkpoint檔案儲存了一個目錄下所有的模型檔案列表,這個檔案是tf.train.Saver類自動生成且自動維護的。在 checkpoint檔案中維護了由一個tf.train.Saver類持久化的所有TensorFlow模型檔案的檔名。當某個儲存的TensorFlow模型檔案被刪除時,這個模型所對應的檔名也會從checkpoint檔案中刪除。checkpoint中內容的格式為CheckpointState Protocol Buffer.
  2. model.ckpt.meta檔案儲存了TensorFlow計算圖的結構,可以理解為神經網路的網路結構 。TensorFlow通過元圖(MetaGraph)來記錄計算圖中節點的資訊以及執行計算圖中節點所需要的元資料。TensorFlow中元圖是由MetaGraphDef Protocol Buffer定義的。MetaGraphDef 中的內容構成了TensorFlow持久化時的第一個檔案。儲存MetaGraphDef 資訊的檔案預設以.meta為字尾名,檔案model.ckpt.meta中儲存的就是元圖資料。
  3. model.ckpt檔案儲存了TensorFlow程式中每一個變數的取值,這個檔案是通過SSTable格式儲存的,可以大致理解為就是一個(key,value)列表。model.ckpt檔案中列表的第一行描述了檔案的元資訊,比如在這個檔案中儲存的變數列表。列表剩下的每一行儲存了一個變數的片段,變數片段的資訊是通過SavedSlice Protocol Buffer定義的。SavedSlice型別中儲存了變數的名稱、當前片段的資訊以及變數取值。TensorFlow提供了tf.train.NewCheckpointReader類來檢視model.ckpt檔案中儲存的變數資訊。如何使用tf.train.NewCheckpointReader類這裡不做說明,自查。
讀取
# Part2: 載入TensorFlow模型的方法  

import tensorflow as tf  

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
result = v1 + v2  

saver = tf.train.Saver()  

with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt") # 注意此處路徑前新增"./"  
    print(sess.run(result)) # [ 3.]  


# Part3: 若不希望重複定義計算圖上的運算,可直接載入已經持久化的圖  

import tensorflow as tf  

saver = tf.train.import_meta_graph("Model/model.ckpt.meta")  

with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt") # 注意路徑寫法  
    print(sess.run(tf.get_default_graph().get_tensor_by_name("add:0"))) # [ 3.]  

在上面給出的程式中,預設儲存和載入了TensorFlow計算圖上定義的全部變數。有時可能只需要儲存或者載入部分變數。 比如,可能有一個之前訓練好的5層神經網路模型,但現在想寫一個6層的神經網路,那麼可以將之前5層神經網路中的引數直接載入到新的模型,而僅僅將最後一層神經網路重新訓練。為了儲存或者載入部分變數,在宣告tf.train.Saver類時可以提供一個列表來指定需要儲存或者載入的變數。比如在載入模型的程式碼中使用saver = tf.train.Saver([v1])命令來構建tf.train.Saver類,那麼只有變數v1會被載入進來。

tf.train.Saver類也支援在儲存和載入時給變數重新命名,宣告Saver類物件的時候使用一個字典dict重新命名變數即可,{“已儲存的變數的名稱name”: 重新命名變數名},saver = tf.train.Saver({“v1”:u1, “v2”: u2})即原來名稱name為v1的變數現在載入到變數u1(名稱name為other-v1)中。

# Part4: tf.train.Saver類也支援在儲存和載入時給變數重新命名  
import tensorflow as tf  

# 宣告的變數名稱name與已儲存的模型中的變數名稱name不一致  
u1 = tf.Variable(tf.constant(1.0, shape=[1]), name="other-v1")  
u2 = tf.Variable(tf.constant(2.0, shape=[1]), name="other-v2")  
result = u1 + u2  

# 若直接生命Saver類物件,會報錯變數找不到  
# 使用一個字典dict重新命名變數即可,{"已儲存的變數的名稱name": 重新命名變數名}  
# 原來名稱name為v1的變數現在載入到變數u1(名稱name為other-v1)中  
saver = tf.train.Saver({"v1": u1, "v2": u2})  

with tf.Session() as sess:  
    saver.restore(sess, "./Model/model.ckpt")  
    print(sess.run(result)) # [ 3.]  

使用tf. train. Saver 會儲存執行TensorFlow 程式所需要的全部資訊,然而有時並不需要某些資訊。比如在測試或者離線預測時,只需要知道如何從神經網路的輸入層經過前向傳播計算得到輸出層即可,而不需要類似於變數初始化、模型儲存等輔助節點的資訊。在遷移學習中,會遇到類似的情況。而且,將變數取值和計算圖結構分成不同的檔案儲存有時候也不方便,於是TensorFlow 提供了convert_variables_to_constants 函式,通過這個函式可以將計算圖中的變數及其取值通過常量的方式儲存,這樣整個TensorFlow 計算圖可以統一存放在一個檔案中。下面的程式提供了一個樣例。

import tensorflow as tf  
from tensorflow.python.framework import graph_util  

v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1")  
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2")  
result = v1 + v2  

init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init op)
    #匯出當前計算圖的GraphDef 部分,只需要這一部分就可以完成從輸入層到輸出層的計算過程。
    graph_def = tf.get_default_graph().as_graph_def()
    #將圖中的變數及其取值轉化為常量,同時將圖中不必要的節點去掉。在5.4.2 小節中將會看
    #到一些系統運算也會被轉化為計算圖中的節點(比如變數初始化操作)。如果只關心程式中定
    #義的某些計算時,和這些計算無關的節點就沒有必要匯出並儲存了。在下面一行程式碼中,最
    #後一個引數[ 'add'] 給出了需要儲存的節點名稱。add 節點是上面定義的兩個變數相加的
    #操作。注意這裡給出的是計算節點的名稱,所以沒有後面的O
    output_graph_def = graph_util.convert_variables_to_constants(sess , graph_def, ['add'])
    #將匯出的模型存入檔案。
    with tf.gfile.GFile("/path/to/model/combined_model.pb" , "wb") as f:
        f.write(output_graph_def.SerializeToString())

通過下面的程式可以直接計算定義的加法運算的結果。當只需要得到計算圖中某個節點的取值時,這提供了一個更加方便的方法。(這個可以用來實現遷移學習)

import tensorflow as tf  
from tensorflow.python.platform import gfile  

with tf.Session() as sess:  
    model_filename = "Model/combined_model.pb"  
    #讀取儲存的模型檔案,並將檔案解析成對應的GraphDef Protocol Buffer。
    with gfile.FastGFile(model_filename, 'rb') as f:  
        graph_def = tf.GraphDef()  
        graph_def.ParseFromString(f.read())  

    result = tf.import_graph_def(graph_def, return_elements=["add:0"])  
    print(sess.run(result)) # [array([ 3.], dtype=float32)]  

持久化原理及資料格式

TensorFlow 是一個通過圖的形式來表述計算的程式設計系統,TensorFlow 程式中的所有計算都會被表達為計算圖上的節點。TensorFlow 通過元圖( MetaGraph )來記錄計算圖中節點的資訊以及執行計算圖中節點所需要的元資料。TensorFlow 中元圖是由MetaGraphDef Protocol Buffer 定義的。MetaGraphDef 中的內容就構成了TensorFlow 持久化時的第一個檔案。以下程式碼給出了MetaGraphDef 型別的定義。

message MetaGraphDef {
    MetaInfoDef meta_info_def = 1;
    GraphDef graph_def = 2;
    SaverDef saver_def = 3;
    map<string, CollectionDef> collection_def = 4;
    map<string, SignatureDef> signature_def = 5;
}

從上面的程式碼中可以看到,元圖中主要記錄了5 類資訊。儲存MetaGraphDef 資訊的檔案預設以meta 為字尾名,檔案model. ckpt. meta 中儲存的就是元圖的資料。為了方便除錯, TensorFlow 提供了export_ meta _graph 函式,這個函式支援以Json 格式匯出MetaGraphDef Protocol Buffer 。以下程式碼展示瞭如何使用這個函式。

import tensorflow as tf
#定義變數相加的計算。
v1 = tf.Variable(tf.constant(1.0, shape=[1]), name="v1" )
v2 = tf.Variable(tf.constant(2.0, shape=[1]), name="v2" )
result1 = v1 + v2
saver = tf.train.Saver()
#通過export_meta_graph 函式匯出TensorFlow 計算圖的元圖,並儲存為json 格式。
saver.export_meta_graph("/path/to/model.ckpt.meda.json", as_text=True)

通過上面給出的程式碼,可以將計算圖元圖以Json 的格式匯出並存儲在model.ckpt.meta.json 檔案中。下文將結合model.ckpt.meta.json 檔案具體介紹TensorFlow 元圖中儲存的資訊

meta_info_def屬性

meta_info_def 屬性是通過MetalnfoDef 定義的,它記錄了TensorFlow 計算圖中的元資料以及TensorFlow 程式中所有使用到的運算方法的資訊。下面是MetalnfoDef Protocol Buffer 的定義:

message MetaInfoDef {
    string meta_graph_version = 1; #計算圖的版本號
    OpList stripped_op_list = 2;
    google.protobuf.Any any_info = 3;
    repeated string_tags = 4;  #使用者指定的一些標籤
}

stripped_op_list 屬性記錄了TensorFlow 計算圖上使用到的所有運算方法的資訊。注意stripped_op_list屬性儲存的是TensorFlow 運算方法的資訊,所以如果某一個運算在TensorFlow 計算圖中出現了多次,那麼在stripped_op_list也只會出現一次。
stripped_op_list 屬性的型別是OpList。OpList 型別是一個OpDef 型別
的列表,以下程式碼給出了OpDef 型別的定義:

message OpDef {
    string name = 1;
    repeated ArgDef input_arg = 2;
    repeated ArgDef output arg = 3 ;
    repeated AttrDef attr = 4 ;
    string summary = 5;
    string description = 6;
    OpDeprecation deprecation = 8;
    boo1 is_commutative = 18;
    bool is_aggregate = 16
    bool is_stateful = 17;
    bool allows_uninitialized_input = 19 ;

OpDef 型別中前四個屬性定義了一個運算最核心的資訊。OpDef 中的第一個屬性name定義了運算的名稱,這也是一個運算唯一的識別符號。在TensorFlow 計算圖元圖的其他屬性中,比如下面將要介紹的GraphDef 屬性,將通過運算名稱來引用不同的運算。OpDef 的第二和第三個屬性為input_arg 和output_arg,它們定義了運算的輸入和輸出。因為輸入輸出都可以有多個,所以這兩個屬性都是列表(repeated) 。第四個屬性attr給出了其他的運算引數資訊。

graph_def屬性

graph_def屬性主要記錄了TensorFlow 計算圖上的節點資訊。TensorFlow 計算圖的每一個節點對應了TensorFlow 程式中的一個運算。因為在meta _info_def屬性中己經包含了所有運算的具體資訊,所以graph_def屬性只關注運算的連線結構。graph_def屬性是通過GraphDef Protocol Buffer 定義的, GraphDef 主要包含了一個NodeDef 型別的列表。以下程式碼給出了GraphDef 和NodeDef 型別中包含的資訊:

message GraphDef {
    repeated NodeDef node = 1;
    VersionDef versions = 4 ;
} ;

message NodeDef {
    string name = 1;
    string op = 2;
    repeated string input = 3;
    string device = 4;
    map<string, AttrValue> attr = 5;
}

GraphDef 中的versions 屬性比較簡單,它主要儲存了TensorFlow 的版本號。GraphDef的主要資訊都存在node屬性中,它記錄了TensorFlow 計算圖上所有的節點資訊。

  1. 和其他屬性類似,NodeDef 型別中有一個名稱屬性name ,它是一個節點的唯一識別符號。在TensorFlow 程式中可以通過節點的名稱來獲取相應的節點。
  2. NodeDef 型別中的op 屬性給出了該節點使用的TensorFlow 運算方法的名稱,通過這個名稱可以在TensorFlow 計算圖元圖的meta info def 屬性中找到該運算的具體資訊。
  3. NodeDef 型別中的input 屬性是一個字串列表,它定義了運算的輸入。input 屬性中每個字串的取值格式為node:src_output ,其中node 部分給出了一個節點的名稱, src _output部分表明了這個輸入是指定節點的第幾個輸出。當src_output 為0 時,可以省略: src_output這個部分。比如node:0 表示名稱為node 的節點的第一個輸出,它也可以被記為node 。
  4. NodeDef 型別中的device 屬性指定了處理這個運算的裝置。執行TensorFlow 運算的裝置可以是本地機器的CPU 或者GPU ,也可以是一臺遠端的機器CPU 或者GPU 。

saver_def 屬性

saver_def 屬性中記錄了持久化模型時需要用到的一些引數,比如儲存到檔案的檔名、儲存操作和載入操作的名稱以及儲存頻率、清理歷史記錄等。saver_def 屬性的型別為SaverDef,其定義如下。

message SaverDef {
    string filename_tensor_name = 1 ;
    string save_tensor_name = 2;
    string restore_op_name = 3;
    int32 max_to_keep = 4;
    bool sharded = 5;
    float keep_checkpoint_every_n_hours = 6;
    enum CheckpointFormatVersion {
        LEGACY = 0;
        V1 = 1;
        V2 = 2;
    }
    CheckpointFormatVersion version = 7;
}

filename_tensor_name 屬性給出了儲存檔名的張量名稱,這個張量就是節點save/Const 的第一個輸出。save_tensor_name 屬性給出了持久化TensorFlow 模型的運算所對應的節點名稱。從上面的檔案中可以看出,這個節點就是在graph_def 屬性中給出的save/control_dependency 節點。和持久化TensorFlow 模型運算對應的是載入TensorFlow 模型的運算,這個運算的名稱由restore_op_name 屬性指定。max_to_keep 屬性和keep_checkpoint_every_n_hours 屬性設定了tf.train.Saver 類清理之前儲存的模型的策略。比如當max_to_keep 為5 的時候,在第六次呼叫saver.save 時,第一次儲存的模型就會被自動刪除。通過設定keep_checkpoint_every_n_hours ,每n 小時可以在max_t_keep 的基礎上多儲存一個模型。

collection_def屬性

在TensorFlow 的計算圖( tf. Graph) 中可以維護不同集合, 而維護這些集合的底層實現就是通過collection_def 這個屬性。collection_def 屬性是一個從集合名稱到集合內容的對映,其中集合名稱為字串,而集合內容為CollectionDef Protocol Buffer 。以下程式碼給出了CollectionDef 型別的定義。

message CollectionDef {
    message NodeList {
        repeated string value = 1;
    }
    message BytesList {
        repeated bytes value = 1;
    }
    message Int64List {
        repeated int64 va1ue = 1 [packed = true];
        }
    message FloatList {
        repeated f1oat value = 1 [packed = true];
    }
    message AnyList {
        repeated google.protobuf.Any value = 1;
    }
    oneof kind {
        NodeList node_list = 1;
        BytesList bytes_list = 2;
        Int64List int64_list = 3 ;
        FloatList f1oat_list = 4;
        AnyList any_list = 5;
    }
}

通過上面的定義可以看出, TensorFlow 計算圖上的集合主要可以維護4 類不同的集合。NodeList 用於維護計算圖上節點的集合。BytesList 可以維護字串或者系列化之後的Procotol Buffer 的集合。比如張量是通過Protocol Buffer 表示的, 而張量的集合是通過BytesList 維護的。

mode1.ckpt 檔案中列表的第一行描述了檔案的元資訊,比如在這個檔案中儲存的變數列表。列表剩下的每一行儲存了一個變數的片段,變數片段的資訊是通過SavedSliceProtocol Buffer 定義的。SavedSlice 型別中儲存了變數的名稱、當前片段的資訊以及變數取值。TensroFlow 提供了tf.train.NewCheckpoin tReader 類來檢視mode1.ckpt 檔案中儲存的變數資訊。
最後一個檔案的名字是固定的,叫checkpoint。這個檔案是tf.train.Saver 類自動生成且自動維護的。在checkpoint 檔案中維護了由一個tf.train.Saver 類持久化的所有TensorFlow模型檔案的檔名。當某個儲存的TensorFlow 模型檔案被刪除時,這個模型所對應的檔名也會從checkpoint 檔案中刪除。checkpoint 中內容的格式為CheckpointState Protocol Buffer。