1. 程式人生 > 程式設計 >tensorflow使用freeze_graph.py將ckpt轉為pb檔案的方法

tensorflow使用freeze_graph.py將ckpt轉為pb檔案的方法

廢話少說直接上程式碼樣例如下

import tensorflow as tf
import os
from tensorflow.python.tools import freeze_graph
# 本來這個model本無需解釋太多,但是這麼多人不能耐下心來看,那麼我簡單的說一下吧
# network是你們自己定義的模型結構而已
# ps:
# def network(input):
# return tf.layers.max_pooling2d(input,2,2)
from model import network


os.environ['CUDA_VISIBLE_DEVICES']='2' #設定GPU


model_path = "path to /model.ckpt-0000" #設定model的路徑,因新版tensorflow會生成三個檔案,只需寫到數字前


def main():

 tf.reset_default_graph()

 input_node = tf.placeholder(tf.float32,shape=(228,304,3)) #這個是你送入網路的圖片大小,如果你是其他的大小自行修改
 input_node = tf.expand_dims(input_node,0)
 flow = network(input_node)
 flow = tf.cast(flow,tf.uint8,'out') #設定輸出型別以及輸出的介面名字,為了之後的呼叫pb的時候使用

 saver = tf.train.Saver()
 with tf.Session() as sess:

  saver.restore(sess,model_path)

  #儲存圖
  tf.train.write_graph(sess.graph_def,'output_model/pb_model','model.pb')
  #把圖和引數結構一起
  freeze_graph.freeze_graph('output_model/pb_model/model.pb','',False,model_path,'out','save/restore_all','save/Const:0','output_model/pb_model/frozen_model.pb',"")

 print("done")

if __name__ == '__main__':
 main()

這節是關於tensorflow的Freezing,字面意思是冷凍,可理解為整合合併;整合什麼呢,就是將模型檔案和權重檔案整合合併為一個檔案,主要用途是便於釋出。

官方解釋可參考:https://www.tensorflow.org/extend/tool_developers/#freezing

這裡我按我的理解翻譯下,不對的地方請指正:
有一點令我們為比較困惑的是,tensorflow在訓練過程中,通常不會將權重資料儲存的格式檔案裡(這裡我理解是模型檔案),反而是分開儲存在一個叫checkpoint的檢查點檔案裡,當初始化時,再通過模型檔案裡的變數Op節點來從checkoupoint檔案讀取資料並初始化變數。這種模型和權重資料分開儲存的情況,使得釋出產品時不是那麼方便,所以便有了freeze_graph.py指令碼檔案用來將這兩檔案整合合併成一個檔案。

freeze_graph.py是怎麼做的呢?首行它先載入模型檔案,再從checkpoint檔案讀取權重資料初始化到模型裡的權重變數,再將權重變數轉換成權重 常量 (因為 常量 能隨模型一起儲存在同一個檔案裡),然後再通過指定的輸出節點將沒用於輸出推理的Op節點從圖中剝離掉,再重新儲存到指定的檔案裡(用write_graphdef或Saver)

檔案目錄:tensorflow/python/tools/free_graph.py
測試檔案:tensorflow/python/tools/free_graph_test.py 這個測試檔案很有學習價值

引數:

總共有11個引數,一個個介紹下(必選: 表示必須有值;可選: 表示可以為空):

1、input_graph:(必選)模型檔案,可以是二進位制的pb檔案,或文字的meta檔案,用input_binary來指定區分(見下面說明)
2、input_saver:(可選)Saver解析器。儲存模型和許可權時,Saver也可以自身序列化儲存,以便在載入時應用合適的版本。主要用於版本不相容時使用。可以為空,為空時用當前版本的Saver。
3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進位制,為false時,input_graph為檔案。預設False
4、input_checkpoint:(必選)檢查點資料檔案。訓練時,給Saver用於儲存權重、偏置等變數值。這時用於模型恢復變數值。
5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。
6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。預設:save/restore_all
7、filename_tensor_name:(可選)已棄用。預設:save/Const:0
8、output_graph:(必選)用來儲存整合後的模型輸出檔案。
9、clear_devices:(可選),預設True。指定是否清除訓練時節點指定的運算裝置(如cpu、gpu、tpu。cpu是預設)
10、initializer_nodes:(可選)預設空。許可權載入後,可通過此引數來指定需要初始化的節點,用逗號分隔多個節點名字。
11、variable_names_blacklist:(可先)預設空。變數黑名單,用於指定不用恢復值的變數,用逗號分隔多個變數名字。

用法:

例:python tensorflow/python/tools/free_graph.py \
–input_graph=some_graph_def.pb \ 注意:這裡的pb檔案是用tf.train.write_graph方法儲存的
–input_checkpoint=model.ckpt.1001 \ 注意:這裡若是r12以上的版本,只需給.data-00000….前面的檔名,如:model.ckpt.1001.data-00000-of-00001,只需寫model.ckpt.1001
–output_graph=/tmp/frozen_graph.pb
–output_node_names=softmax

另外,如果模型檔案是.meta格式的,也就是說用saver.Save方法和checkpoint一起生成的元模型檔案,free_graph.py不適用,但可以改造下:
1、copy free_graph.py為free_graph_meta.py
2、修改free_graph.py,匯入meta_graph:from tensorflow.python.framework import meta_graph
3、將91行到97行換成:input_graph_def = meta_graph.read_meta_graph_file(input_graph).graph_def

這樣改即可載入meta檔案

到此這篇關於tensorflow使用freeze_graph.py將ckpt轉為pb檔案的方法的文章就介紹到這了,更多相關tensorflow ckpt轉為pb檔案內容請搜尋我們以前的文章或繼續瀏覽下面的相關文章希望大家以後多多支援我們!