1. 程式人生 > >Tensorflow-GraphDef、MetaGraph、CheckPoint

Tensorflow-GraphDef、MetaGraph、CheckPoint

Tensorflow框架實現的三種圖

 

參考原文:http://www.360doc.com/content/17/1123/18/7669533_706522939.shtml

==========================

Graph:

Tensorflow所執行的程式碼,或者說用python程式碼表達的計算,所描述的物件實際上就是一張計算圖,包含了各個運算節點和用於計算的張量。而Graph_def是圖Graph的序列表示。python所描述的這個graph,並不是在執行Tensorflow,啟動一個Session後就保持不變的,因為Tensorflow在實際執行過程中,真實的計算是會被下放到多CPU,或者GPU、ARM等異構裝置上進行高效能運算的,如果僅僅單純地使用python肯定是無法有效地完成計算的。所以Tensorflow的實際計算過程是這樣的:

Tensorflow先將python程式碼所描繪的圖進行轉換,轉化成Protocol Buffer(即序列化),再通過C/C++/CUDA執行Protocol Buffer所定義的圖。

(Protocol Buffer:

https://www.ibm.com/developerworks/cn/linux/l-cn-gpb/

Tensorflow實戰Google深度學習框架Chapter 2 )

==========================

GraphDef:

從python程式碼描述的Graph中序列化得到的圖就叫做GraphDef。GraphDef可以理解為一種資料結構。GraphDef是由許多叫做NodeDef的Protocol Buffer組成。其中NodeDef也可以理解為是資料結構。(實際上從資料結構的角度上就很好理解這些內容)。GraphDef強調的是操作節點之間的聯絡。Tensorflow中通過NodeDef中的input這一attribute來定義Node之間的連線資訊。

在概念上,NodeDef與python程式碼描繪的Graph中的操作運算節點Operation相對應。可知GraphDef中只有NodeDef,也就是說只有python描述的Graph中的Operation,並沒有Variable。所以這也反映出了GraphDef這個圖強調的是python描述的Graph的連線資訊,並不儲存Variable的相關資訊(注意並不是所有Tensor的相關資訊都不儲存,constant型別的Tensor的相關資訊就會在GraphDef中儲存)。所以如果要從graph_def來構建圖並恢復訓練的話,是不一定能成功的,因為缺少了例如Variable等這些Tensor。

在實際線上Inference中,通常使用的是GraphDef。雖然GraphDef不會儲存Variable這類Tensor,但是會儲存constant這類Tensor,所以還是可以用來儲存例如weights這些引數的。在Tensorflow 1.3.0版本中提供了一套叫做freeze_graph的工具來自動地將python所描述的Graph中的Variable替換成constant儲存在GraphDef中,並將該Graph匯出為Proto.

(freeze_graph:

https://www.tensorflow.org/extend/tool_developers/

tf.train.writer_graph()/tf.import_graph_def()就是用來進行GraphDef讀寫的API。

可知如果僅僅從GraphDef中是無法得到Variable的。

==========================

MetaGraph:

在GraphDef中無法得到Variable,而通過MetaGraph可以得到。

MetaGraph的官方解釋:一個MetaGraph是由一個計算圖和其相關的元資料構成的。其包含了用於繼續訓練、實施評估和(在已經訓練好的Graph圖上)做前向推斷的資訊。

https://www.tensorflow.org/versions/r1.1/programmers_guide/

MetaGraph在具體實現上,就是一個MetaGraphDef(同樣是由Protocol Buffer來定義的)。其中包含了四種主要的資訊:

MetaInfoDef: 存放了一些元資訊,例如版本和其他使用者資訊;

GraphDef: MetaGraph的核心內容之一;

SaverDef: 圖的Saver資訊,例如最多同時儲存的checkpoint數量,需要儲存的Tensor名字等,但並不儲存Tensor中的實際內容;

CollectionDef: 任何需要特殊注意的python物件,需要特殊的標註以方便import_meta_graph後取回,例如”train_op”,”prediction”等等。

其中著重介紹CollectionDef,其為Collection對應的Protocol Buffer。

集合collection是為了方便使用者對圖中的操作和變數進行管理而被建立的一個概念,通過一個string型別的key來對一組python物件進行命名的集合。這個key可以是Tensorflow在內部定義的一些key,也可以是使用者自定義的名字,但是注意是string型別。它有一點名稱空間的意思,將變數收錄進某一個集合collection中。

Tensorflow內部定義了許多標準的key,全部定義在了tf.GraohKeys這個類當中。其中有一些是常用的,tf.GraphKeys.TRAINABLE_VARIABLES, tf.GraphKeys.GLOBAL_VARIABLES等等。tf.trainable_variables()和tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES)是等價的;tf.global_variables()和tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)是等價的。

對於使用者定義的key:

pred = model_network(X)

loss = tf.readuce_mean(…, pred, …)

train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

對於這一段對訓練過程定義的程式碼,使用者希望特別關注pred, loss, train_op這幾個操作,那麼就可以使用如下程式碼,將這幾個變數加入集合collection中去。令這個集合名為”training_collection”:

tf.add_to_collection(‘training_collection’, pred)

tf.add_to_collection(‘training_collection’, loss)

tf.add_to_collection(‘training_collection’, train_op)

並且可以通過Train_collect = tf.get_collection(‘training_collection’)得到一個python的list,list中的元素就是加入集合的幾個變數pred, loss, train_op。這通常是為了在一個新的Session中開啟這張Graph時,方便我們獲取想要的操作節點Operation。例如可以通過tf.get_collection()得到train_op,然後通過sess.run(train_op)來進行訓練,而無需重新構建loss和Optimizer。

通過tf.export_meta_graph()儲存Graph,得到MetaGraph,並通過add_to_collection()將操作Operation加入collection中:

with tf.Session() as sess:

pred = model_network(X)

loss = tf.readuce_mean(…, pred, …)

train_op = tf.train.AdamOptimizer(learning_rate).minimize(loss)

tf.add_to_collection(‘training_collection’, train_op)

Meta_graph_def = tf.train.export_meta_graph(tf.get_default_graph(), ‘my_graph.meta’)

通過import_meta_graph將MetaGraph恢復,同時初始化為本Session的default Graph,並通過get_collection重新獲得train_op,以及通過train_op開始一段訓練(sess.run())。

從MetaGraph中恢復構建的圖Graph是可以被訓練的。

https://www.tensorflow.org/api_guides/python/meta_graph

需要特殊說明的是,MetaGraph中雖然包含Variable的資訊,但是沒有Variable的實際值。所以從MetaGraph中恢復的圖Graph,訓練都是從隨機初始化的值開始的,訓練中的Variable 的實際值都儲存在checkpoint檔案中,如果要從之前訓練的狀態繼續恢復訓練,就需要從checkpoint中restore。

tf.export_meta_graph()/tf.import_meta_graph()即為用來進行MetaGraph讀寫的API。tf.train.saver.save()在儲存checkpoint的同時也會儲存MetaGraph,但是在恢復圖時,tf.train.saver.restore()只恢復Variable。如果要從MetaGraph中恢復圖Graph,需要使用tf.import_meta_graph()。這其實是為了方便使用者,因為有時我們不需要從MetaGraph中恢復圖Graph,而僅僅需要在python中構建NN的Graph,並恢復對應的Variable。

==========================

CheckPoint:

CheckPoint中全面儲存了訓練某時間截面的資訊,包括引數、超引數、梯度等等。tf.train.Saver()/tf.saver.restore()則能夠完整地儲存和恢復神經網路的訓練。CheckPoint分為兩個檔案儲存Variable的二進位制資訊:ckpt檔案儲存了Variable的二進位制資訊,index檔案用於儲存ckpt檔案中對應Variable的偏移量資訊。

==========================

總結:

Tensorflow三種API所儲存和恢復的圖Graph是不一樣的。這三種圖是從Tensorflow框架設計的角度出發定義的。簡而言之,Tensorflow在前段python中構建圖Graph,並且通過將該圖序列化到Protocol Buffer得到GraphDef,以方便在後端執行。在這個過程中,圖的儲存、恢復、執行都通過ProtoBuf來實現。GraphDef、MetaGraph以及Variable、Collection、Saver等都有對應的ProtoBuf定義。ProtoBuf的定義也決定了使用者能對圖進行的操作。例如使用者只能找到Node的前一個Node,卻無法得知自己的輸出會被哪個Node接受。