從Tensorflow模型檔案中解析並顯示網路結構圖(CKPT模型篇)
上一篇文章《從Tensorflow模型檔案中解析並顯示網路結構圖(pb模型篇)》中介紹瞭如何從pb
模型檔案中提取網路結構圖並實現視覺化,本文介紹如何從CKPT
模型檔案中提取網路結構圖並實現視覺化。理論上,既然能從pb
模型檔案中提取網路結構圖,CKPT
模型檔案自然也不是問題,但是其中會有一些問題。
1 解析CKPT網路結構
解析CKPT
網路結構的第一步是讀取CKPT
模型中的圖檔案,得到圖的Graph
物件後即可得到完整的網路結構。讀取圖檔案示例程式碼如下所示。
saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices= True)
graph = tf.get_default_graph()
with tf.Session( graph=graph) as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess,ckpt_path)
呼叫graph.get_operations()
後即可得到當前圖的所有計算節點,在利用Operation
物件與Tensor
物件之間的相互引用關係即可推斷網路結構。但是需要注意的是,從meta
檔案中匯入的圖中獲取計算節點存在如下問題。
- 包含反向梯度下降計算的所有節點
- 某些計算節點是按基礎計算(加減乘除等)節點拆分成多個計算節點的,如
BatchNorm
,但其實是可以直接合併成一個節點的。
pb
模型檔案可以避免上面第一個問題,將CKPT
模型轉pb
模型後,可以自動將反向梯度下降相關計算節點移除。對於第二點,pb
模型檔案會自動將基礎計算組成一個計算節點,但是對於Tensor操作的函式如Slice等函式是無法合併的。因此,對於第2個問題,將CKPT
模型轉pb
模型後,可以減少這類問題,但是無法避免。徹底避免的方法只能通過自己針對性地實現。經過以上分析,得出的結論是非常有必要將CKPT
模型轉pb
模型。
2 自動將CKPT轉pb,並提取網路圖中節點
def read_graph_from_ckpt (ckpt_path,input_names,output_name ):
saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True)
graph = tf.get_default_graph()
with tf.Session( graph=graph) as sess:
sess.run(tf.global_variables_initializer())
saver.restore(sess,ckpt_path)
output_tf =graph.get_tensor_by_name(output_name)
pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name])
with tf.Graph().as_default() as g:
tf.import_graph_def(pb_graph, name='')
with tf.Session(graph=g) as sess:
OPS=get_ops_from_pb(g,input_names,output_name)
return OPS
3 測試
以《MobileNet V1官方預訓練模型的使用》文中介紹的MobileNet V1網路結構為例,下載MobileNet_v1_1.0_192檔案並壓縮後,得到mobilenet_v1_1.0_192.ckpt.data-00000-of-00001
、mobilenet_v1_1.0_192.ckpt.index
、mobilenet_v1_1.0_192.ckpt.meta
檔案。我們還需要知道mobilenet_v1_1.0_192.ckpt
模型對應的輸入和輸出Tensor
物件的名稱,官方提供的壓縮包檔案中並沒有告知。一種方法是執行官方程式碼,把輸入Tensor的名稱打印出來。但是執行官方程式碼本身就需要一定的時間和精力,在在上一篇文章《從Tensorflow模型檔案中解析並顯示網路結構圖(pb模型篇)》的程式碼實現中已經實現了將原始網路結構對應的字串寫入到ori_network.txt
檔案中。因此,可以先隨意填寫輸入名稱和輸出名稱,待生成ori_network.txt
檔案後,從檔案中可以直觀看到原始網路結構。ori_network.txt
檔案部分內容如下所示。
通過該檔案可知,輸入Tensor
的名稱為:batch:0
,輸出Tensor
名稱為:MobilenetV1/Predictions/Reshape_1:0
。有了這些資訊後,呼叫函式read_graph_from_ckpt
得到靜態圖的節點列表物件ops
,呼叫函式gen_graph(ops,"save/path/graph.html")
後,在目錄save/path
中得到graph.html
檔案,開啟graph.html
後,顯示結果如下。