1. 程式人生 > 程式設計 >淺談tensorflow模型儲存為pb的各種姿勢

淺談tensorflow模型儲存為pb的各種姿勢

一,直接儲存pb

1,首先我們當然可以直接在tensorflow訓練中直接儲存為pb為格式,儲存pb的好處就是使用場景是實現建立模型與使用模型的解耦,使得建立模型與使用模型的解耦,使得前向推導inference程式碼統一。另外的好處就是儲存為pb的時候,模型的變數會變成固定的,導致模型的大小會大大減小。

這裡稍稍解釋下pb:是MetaGraph的protocol buffer格式的檔案,MetaGraph包括計算圖,資料流,以及相關的變數和輸入輸出

主要使用tf.SavedModelBuilder來完成這個工作,並且可以把多個計算圖儲存到一個pb檔案中,如果有多個MetaGraph,那麼只會保留第一個MetaGraph的版本號。

保持pb的檔案程式碼:

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32,name='x')
 y = tf.placeholder(tf.int32,name='y')
 b = tf.Variable(1,name='b')
 xy = tf.multiply(x,y)
 # 這裡的輸出需要加上name屬性
 op = tf.add(xy,b,name='op_to_store')
 
 sess.run(tf.global_variables_initializer())
 
 # convert_variables_to_constants 需要指定output_node_names,list(),可以多個
 constant_graph = graph_util.convert_variables_to_constants(sess,sess.graph_def,['op_to_store'])
 
 # 測試 OP
 feed_dict = {x: 10,y: 3}
 print(sess.run(op,feed_dict))
 
 # 寫入序列化的 PB 檔案
 with tf.gfile.FastGFile(pb_file_path+'model.pb',mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # 輸出
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31

其實主要是:

 # convert_variables_to_constants 需要指定output_node_names,list(),可以多個
 constant_graph = graph_util.convert_variables_to_constants(sess,['op_to_store'])
 # 寫入序列化的 PB 檔案
 with tf.gfile.FastGFile(pb_file_path+'model.pb',mode='wb') as f:
  f.write(constant_graph.SerializeToString())

1.1 載入測試程式碼

from tensorflow.python.platform import gfile
 
sess = tf.Session()
with gfile.FastGFile(pb_file_path+'model.pb','rb') as f:
 graph_def = tf.GraphDef()
 graph_def.ParseFromString(f.read())
 sess.graph.as_default()
 tf.import_graph_def(graph_def,name='') # 匯入計算圖
 
# 需要有一個初始化的過程 
sess.run(tf.global_variables_initializer())
 
# 需要先復原變數
print(sess.run('b:0'))
# 1
 
# 輸入
input_x = sess.graph.get_tensor_by_name('x:0')
input_y = sess.graph.get_tensor_by_name('y:0')
 
op = sess.graph.get_tensor_by_name('op_to_store:0')
 
ret = sess.run(op,feed_dict={input_x: 5,input_y: 5})
print(ret)
# 輸出 26

2,第二種就是採用上述的那API來進行儲存

import tensorflow as tf
import os
from tensorflow.python.framework import graph_util
 
pb_file_path = os.getcwd()
 
with tf.Session(graph=tf.Graph()) as sess:
 x = tf.placeholder(tf.int32,mode='wb') as f:
  f.write(constant_graph.SerializeToString())
 
 # INFO:tensorflow:Froze 1 variables.
 # Converted 1 variables to const ops.
 # 31
 
 
 # 官網有誤,寫成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 構造模型儲存的內容,指定要儲存的 session,特定的 tag,# 輸入輸出資訊字典,額外的資訊
 builder.add_meta_graph_and_variables(sess,['cpu_server_1'])
 
# 新增第二個 MetaGraphDef 
#with tf.Session(graph=tf.Graph()) as sess:
# ...
# builder.add_meta_graph([tag_constants.SERVING])
#...
 
builder.save() # 儲存 PB 模型

核心就是採用了:

 # 官網有誤,寫成了 saved_model_builder 
 builder = tf.saved_model.builder.SavedModelBuilder(pb_file_path+'savemodel')
 # 構造模型儲存的內容,指定要儲存的 session,特定的 tag,['cpu_server_1'])

2.1 對應的測試程式碼為:

with tf.Session(graph=tf.Graph()) as sess:
 tf.saved_model.loader.load(sess,['cpu_1'],pb_file_path+'savemodel')
 sess.run(tf.global_variables_initializer())
 
 input_x = sess.graph.get_tensor_by_name('x:0')
 input_y = sess.graph.get_tensor_by_name('y:0')
 
 op = sess.graph.get_tensor_by_name('op_to_store:0')
 
 ret = sess.run(op,input_y: 5})
 print(ret)
# 只需要指定要恢復模型的 session,模型的 tag,模型的儲存路徑即可,使用起來更加簡單

這樣和之前的匯入pb模型一樣,也是要知道tensor的name,那麼如何在不知道tensor name的情況下使用呢,給add_meta_graph_and_variables方法傳入第三個引數,signature_def_map即可。

二,從ckpt進行載入

使用tf.train.saver()保持模型的時候會產生多個檔案,會把計算圖的結構和圖上引數取值分成了不同檔案儲存,這種方法是在TensorFlow中最常用的儲存方式:

import tensorflow as tf
# 宣告兩個變數
v1 = tf.Variable(tf.random_normal([1,2]),name="v1")
v2 = tf.Variable(tf.random_normal([2,3]),name="v2")
init_op = tf.global_variables_initializer() # 初始化全部變數
saver = tf.train.Saver() # 宣告tf.train.Saver類用於儲存模型
with tf.Session() as sess:
 sess.run(init_op)
 print("v1:",sess.run(v1)) # 列印v1、v2的值一會讀取之後對比
 print("v2:",sess.run(v2))
 saver_path = saver.save(sess,"save/model.ckpt") # 將模型儲存到save/model.ckpt檔案
 print("Model saved in file:",saver_path)

淺談tensorflow模型儲存為pb的各種姿勢

checkpoint是檢查點的檔案,檔案儲存了一個目錄下所有的模型檔案列表

model.ckpt.meta檔案儲存了Tensorflow計算圖的結果,可以理解為神經網路的網路結構,該檔案可以被tf.train.import_meta_graph載入到當前預設的圖來使用

ckpt.data是儲存模型中每個變數的取值

方法一, tensorflow提供了convert_variables_to_constants()方法,改方法可以固化模型結構,將計算圖中的變數取值以常量的形式儲存

ckpt轉換pb格式過程如下:

1,通過傳入ckpt模型的路徑得到模型的圖和變數資料

2,通過import_meta_graph匯入模型中的圖

3,通過saver.restore從模型中恢復圖中各個變數的資料

4,通過graph_util.convert_variables_to_constants將模型持久化

import tensorflow as tf 
from tensorflow.python.framework import graph_util
from tensorflow.pyton.platform import gfile
 
def freeze_graph(input_checkpoint,output_graph):
 '''
 :param input_checkpoint:
 :param output_graph: PB模型儲存路徑
 :return:
 '''
 # checkpoint = tf.train.get_checkpoint_state(model_folder) #檢查目錄下ckpt檔案狀態是否可用
 # input_checkpoint = checkpoint.model_checkpoint_path #得ckpt檔案路徑
 
 # 指定輸出的節點名稱,該節點名稱必須是原模型中存在的節點
 output_node_names = "InceptionV3/Logits/SpatialSqueeze"
 saver = tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=True)
 graph = tf.get_default_graph() # 獲得預設的圖
 input_graph_def = graph.as_graph_def() # 返回一個序列化的圖代表當前的圖
 
 with tf.Session() as sess:
  saver.restore(sess,input_checkpoint) #恢復圖並得到資料
  output_graph_def = graph_util.convert_variables_to_constants( # 模型持久化,將變數值固定
   sess=sess,input_graph_def=input_graph_def,# 等於:sess.graph_def
   output_node_names=output_node_names.split(","))# 如果有多個輸出節點,以逗號隔開
 
  with tf.gfile.GFile(output_graph,"wb") as f: #儲存模型
   f.write(output_graph_def.SerializeToString()) #序列化輸出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到當前圖有幾個操作節點
 
  # for op in graph.get_operations():
  #  print(op.name,op.values())

函式freeze_graph中,最重要的就是指定輸出節點的名稱,這個節點名稱是原模型存在的結點,注意節點名稱與張量名稱的區別:

如:“input:0”是張量的名稱,而“input”表示的是節點的名稱

原始碼中通過graph = tf.get_default_graph()獲得預設圖,這個圖就是由saver = tf.train.import_meta_graph(input_checkpoint + '.meta',clear_devices=True)恢復的圖,因此就必須執行tf.train.import_meta_graph,再執行tf.get_default_graph()

1.2 一個小工具

tensorflow列印pb模型的所有節點

from tensorflow.python.framework import tensor_util
from google.protobuf import text_format 
import tensorflow as tf 
from tensorflow.python.platform import gfile 
from tensorflow.python.framework import tensor_util
 
pb_path = './model.pb'
 
with tf.Session() as sess:
 with gfile.FastGFile(pb_path,'rb') as f:
  graph_def = tf.GraphDef()
 
  graph_def.ParseFromString(f.read())
  tf.import_graph_def(graph_def,name='')
  for i,n in enumerate(graph_def.node):
   print("Name of the node -%s"%n.name)
tensorflow列印ckpt的所有節點

from tensorflow.python import pywrap_tensorflow
checkpoint_path = './_checkpoint/hed.ckpt-130'
 
reader = pywrap_tensorflow.NewCheckpointReader(checkpoint_path)
var_to_shape_map = reader.get_variable_to_shape_map()
for key in var_to_shape_map:
 print("tensor_name:",key)

方法二,除了上述辦法外還有一種是需要通過原始碼的,這樣既可以得到輸出節點,還可以自定義輸入節點。

import tensorflow as tf 
 
def model(input):
 net = tf.layers.conv2d(input,filters=32,kernel_size=3)
 net = tf.layers.batch_normalization(net,fused=False)
 net = tf.layers.separable_conv2d(net,32,3)
 net = tf.layers.conv2d(net,kernel_size=3,name='output')
 
 return net 
 
input_node = tf.placeholder(tf.float32,[1,480,3],name = 'image')
output_node_names = 'head_neck_count/BiasAdd'
ckpt = ckpt_path 
pb = pb_path 
 
with tf.Session() as sess:
 model1 = model(input_node)
 sess.run(tf.global_variables_initializer())
 output_node_names = 'output/BiasAdd'
 
 input_graph_def = tf.get_default_graph().as_graph_def()
 output_graph_def = tf.graph_util.convert_variables_to_constants(sess,input_graph_def,output_node_names.split(','))
 
with tf.gfile.GFile(pb,'wb') as f:
 f.write(output_graph_def.SerializeToString())

注意:

節點名稱和張量名稱區別

類似於output是節點名稱

類似於output:0是張量名稱

方法三,其實是方法一的延伸可以配合tensorflow自帶的一些工具來進行完成

freeze_graph

總共有11個引數,一個個介紹下(必選: 表示必須有值;可選: 表示可以為空):

1、input_graph:(必選)模型檔案,可以是二進位制的pb檔案,或文字的meta檔案,用input_binary來指定區分(見下面說明)

2、input_saver:(可選)Saver解析器。儲存模型和許可權時,Saver也可以自身序列化儲存,以便在載入時應用合適的版本。主要用於版本不相容時使用。可以為空,為空時用當前版本的Saver。

3、input_binary:(可選)配合input_graph用,為true時,input_graph為二進位制,為false時,input_graph為檔案。預設False

4、input_checkpoint:(必選)檢查點資料檔案。訓練時,給Saver用於儲存權重、偏置等變數值。這時用於模型恢復變數值。

5、output_node_names:(必選)輸出節點的名字,有多個時用逗號分開。用於指定輸出節點,將沒有在輸出線上的其它節點剔除。

6、restore_op_name:(可選)從模型恢復節點的名字。升級版中已棄用。預設:save/restore_all

7、filename_tensor_name:(可選)已棄用。預設:save/Const:0

8、output_graph:(必選)用來儲存整合後的模型輸出檔案。

9、clear_devices:(可選),預設True。指定是否清除訓練時節點指定的運算裝置(如cpu、gpu、tpu。cpu是預設)

10、initializer_nodes:(可選)預設空。許可權載入後,可通過此引數來指定需要初始化的節點,用逗號分隔多個節點名字。

11、variable_names_blacklist:(可先)預設空。變數黑名單,用於指定不用恢復值的變數,用逗號分隔多個變數名字。

所以還是建議選擇方法三

匯出pb後的測試程式碼如下:下圖是比較完成的測試程式碼與匯出程式碼。

# -*-coding: utf-8 -*-
"""
 @Project: tensorflow_models_nets
 @File : convert_pb.py
 @Author : panjq
 @E-mail : [email protected]
 @Date : 2018-08-29 17:46:50
 @info :
 -通過傳入 CKPT 模型的路徑得到模型的圖和變數資料
 -通過 import_meta_graph 匯入模型中的圖
 -通過 saver.restore 從模型中恢復圖中各個變數的資料
 -通過 graph_util.convert_variables_to_constants 將模型持久化
"""
 
import tensorflow as tf
from create_tf_record import *
from tensorflow.python.framework import graph_util
 
resize_height = 299 # 指定圖片高度
resize_width = 299 # 指定圖片寬度
depths = 3
 
def freeze_graph_test(pb_path,image_path):
 '''
 :param pb_path:pb檔案的路徑
 :param image_path:測試圖片的路徑
 :return:
 '''
 with tf.Graph().as_default():
  output_graph_def = tf.GraphDef()
  with open(pb_path,"rb") as f:
   output_graph_def.ParseFromString(f.read())
   tf.import_graph_def(output_graph_def,name="")
  with tf.Session() as sess:
   sess.run(tf.global_variables_initializer())
 
   # 定義輸入的張量名稱,對應網路結構的輸入張量
   # input:0作為輸入影象,keep_prob:0作為dropout的引數,測試時值為1,is_training:0訓練引數
   input_image_tensor = sess.graph.get_tensor_by_name("input:0")
   input_keep_prob_tensor = sess.graph.get_tensor_by_name("keep_prob:0")
   input_is_training_tensor = sess.graph.get_tensor_by_name("is_training:0")
 
   # 定義輸出的張量名稱
   output_tensor_name = sess.graph.get_tensor_by_name("InceptionV3/Logits/SpatialSqueeze:0")
 
   # 讀取測試圖片
   im=read_image(image_path,resize_height,resize_width,normalization=True)
   im=im[np.newaxis,:]
   # 測試讀出來的模型是否正確,注意這裡傳入的是輸出和輸入節點的tensor的名字,不是操作節點的名字
   # out=sess.run("InceptionV3/Logits/SpatialSqueeze:0",feed_dict={'input:0': im,'keep_prob:0':1.0,'is_training:0':False})
   out=sess.run(output_tensor_name,feed_dict={input_image_tensor: im,input_keep_prob_tensor:1.0,input_is_training_tensor:False})
   print("out:{}".format(out))
   score = tf.nn.softmax(out,name='pre')
   class_id = tf.argmax(score,1)
   print "pre class_id:{}".format(sess.run(class_id))
 
def freeze_graph(input_checkpoint,clear_devices=True)
 
 with tf.Session() as sess:
  saver.restore(sess,input_graph_def=sess.graph_def,"wb") as f: #儲存模型
   f.write(output_graph_def.SerializeToString()) #序列化輸出
  print("%d ops in the final graph." % len(output_graph_def.node)) #得到當前圖有幾個操作節點
 
  # for op in sess.graph.get_operations():
  #  print(op.name,op.values())
 
def freeze_graph2(input_checkpoint,op.values())
 
if __name__ == '__main__':
 # 輸入ckpt模型路徑
 input_checkpoint='models/model.ckpt-10000'
 # 輸出pb模型的路徑
 out_pb_path="models/pb/frozen_model.pb"
 # 呼叫freeze_graph將ckpt轉為pb
 freeze_graph(input_checkpoint,out_pb_path)
 
 # 測試pb模型
 image_path = 'test_image/animal.jpg'
 freeze_graph_test(pb_path=out_pb_path,image_path=image_path)

以上這篇淺談tensorflow模型儲存為pb的各種姿勢就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。