1. 程式人生 > >從Tensorflow模型檔案中解析並顯示網路結構圖(CKPT模型篇)

從Tensorflow模型檔案中解析並顯示網路結構圖(CKPT模型篇)

上一篇文章《從Tensorflow模型檔案中解析並顯示網路結構圖(pb模型篇)》中介紹瞭如何從pb模型檔案中提取網路結構圖並實現視覺化,本文介紹如何從CKPT模型檔案中提取網路結構圖並實現視覺化。理論上,既然能從pb模型檔案中提取網路結構圖,CKPT模型檔案自然也不是問題,但是其中會有一些問題。

1 解析CKPT網路結構

解析CKPT網路結構的第一步是讀取CKPT模型中的圖檔案,得到圖的Graph物件後即可得到完整的網路結構。讀取圖檔案示例程式碼如下所示。

    saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=
True) graph = tf.get_default_graph() with tf.Session( graph=graph) as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess,ckpt_path)

呼叫graph.get_operations()後即可得到當前圖的所有計算節點,在利用Operation物件與Tensor物件之間的相互引用關係即可推斷網路結構。但是需要注意的是,從meta檔案中匯入的圖中獲取計算節點存在如下問題。

  1. 包含反向梯度下降計算的所有節點
  2. 某些計算節點是按基礎計算(加減乘除等)節點拆分成多個計算節點的,如BatchNorm,但其實是可以直接合併成一個節點的。

pb模型檔案可以避免上面第一個問題,將CKPT模型轉pb模型後,可以自動將反向梯度下降相關計算節點移除。對於第二點,pb模型檔案會自動將基礎計算組成一個計算節點,但是對於Tensor操作的函式如Slice等函式是無法合併的。因此,對於第2個問題,將CKPT模型轉pb模型後,可以減少這類問題,但是無法避免。徹底避免的方法只能通過自己針對性地實現。經過以上分析,得出的結論是非常有必要將CKPT模型轉pb模型。

2 自動將CKPT轉pb,並提取網路圖中節點

def read_graph_from_ckpt
(ckpt_path,input_names,output_name ): saver = tf.train.import_meta_graph(ckpt_path+'.meta',clear_devices=True) graph = tf.get_default_graph() with tf.Session( graph=graph) as sess: sess.run(tf.global_variables_initializer()) saver.restore(sess,ckpt_path) output_tf =graph.get_tensor_by_name(output_name) pb_graph = tf.graph_util.convert_variables_to_constants( sess, graph.as_graph_def(), [output_tf.op.name]) with tf.Graph().as_default() as g: tf.import_graph_def(pb_graph, name='') with tf.Session(graph=g) as sess: OPS=get_ops_from_pb(g,input_names,output_name) return OPS

3 測試

《MobileNet V1官方預訓練模型的使用》文中介紹的MobileNet V1網路結構為例,下載MobileNet_v1_1.0_192檔案並壓縮後,得到mobilenet_v1_1.0_192.ckpt.data-00000-of-00001mobilenet_v1_1.0_192.ckpt.indexmobilenet_v1_1.0_192.ckpt.meta檔案。我們還需要知道mobilenet_v1_1.0_192.ckpt模型對應的輸入和輸出Tensor物件的名稱,官方提供的壓縮包檔案中並沒有告知。一種方法是執行官方程式碼,把輸入Tensor的名稱打印出來。但是執行官方程式碼本身就需要一定的時間和精力,在在上一篇文章《從Tensorflow模型檔案中解析並顯示網路結構圖(pb模型篇)》的程式碼實現中已經實現了將原始網路結構對應的字串寫入到ori_network.txt檔案中。因此,可以先隨意填寫輸入名稱和輸出名稱,待生成ori_network.txt檔案後,從檔案中可以直觀看到原始網路結構。ori_network.txt檔案部分內容如下所示。

ori_network.txt檔案部分內容 通過該檔案可知,輸入Tensor的名稱為:batch:0,輸出Tensor名稱為:MobilenetV1/Predictions/Reshape_1:0。有了這些資訊後,呼叫函式read_graph_from_ckpt得到靜態圖的節點列表物件ops,呼叫函式gen_graph(ops,"save/path/graph.html")後,在目錄save/path中得到graph.html檔案,開啟graph.html後,顯示結果如下。

讀取並顯示CKPT模型的圖結構

4 原始碼地址