tensorflow 讀取模型並進行預測
阿新 • • 發佈:2018-12-28
tensorflow 讀取兩種格式的模型並進行預測
文章目錄
1. 模型儲存
1.1 checkpoint 模型
如圖所示,
.meta
– 儲存圖結構,即神經網路的網路結構
.data
.index
– 是一個不可變得字串表,每一個鍵都是張量的名稱,它的值是一個序列化的BundleEntryProto。 每個BundleEntryProto描述張量的元資料:“資料”檔案中的哪個檔案包含張量的內容,該檔案的偏移量,校驗和,一些輔助資料等等。checkpoint
– 文字檔案,裡面記錄了儲存的最新的checkpoint檔案以及其它checkpoint檔案列表。在inference時,可以通過修改這個檔案,指定使用哪個model.
儲存模型:
saver = tf.train.Saver()
saver.save(sess, model_path)
其中model_path
是模型儲存路徑。
1.2 frozen_graph模型
在工程中,我們往往需要將模型和權重固化,便於釋出和預測。
使用tensorFlow
官方提供的freeze_graph.py
工具來儲存相應模型。(程式碼中把freeze_graph.py
檔案放在commom.utils.tf
路徑下匯入)
freeze_graph.py
先載入模型檔案,從checkpoint檔案讀取權重資料初始化到模型裡的權重變數,再將權重變數轉換成權重常量,然後再通過指定的輸出節點將沒用於輸出推理的Op節點從圖中剝離掉,再重新儲存到指定的檔案裡(用write_graphdef或Saver)。
from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
sess.graph.as_graph_def(),
os.path.join(model_path),
GRAPH_PB_NAME,
as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
input_graph=os.path.join(model_path, GRAPH_PB_NAME),
input_saver=False,
input_binary=True,
input_checkpoint=os.path.join(model_path, CHECKPOINT_PREFIX),
output_node_names="viterbi_sequence,intent_prediction,intent_probs",
restore_op_name=None,
filename_tensor_name=None,
output_graph=os.path.join(model_path, FROZEN_GRAPH_PB_NAME),
clear_devices=False,
initializer_nodes="",
variable_names_whitelist="",
variable_names_blacklist="",
input_meta_graph=None,
input_saved_model_dir=None,
saved_model_tags=tf.saved_model.tag_constants.SERVING,
checkpoint_version=saver_pb2.SaverDef.V2)
其中model_path
是模型儲存路徑,GRAPH_PB_NAME
定義了圖模型的名字。
freeze_graph主要引數(參考[4]部落格中的引數說明):
input_graph
: 模型檔案,可以是二進位制的pb檔案,或文字的meta檔案,用input_binary
來指定區分。input_checkpoint
: 檢查點資料檔案。output_node_names
: 輸出節點的名字,有多個時用逗號分開。output_graph
: 儲存整合後的輸出模型。
2. 讀取ckpt模型
在我的模型中,要求的輸入有四個,分別是inputs_vocab
,inputs_feature_list
,sequence_length
,max_length
。
計算得到的輸出有兩個viterbi_sequence
和intent_prediction
。
ckpt = tf.train.get_checkpoint_state(arg.model_path + '/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')
with tf.Session() as sess:
saver.restore(sess, ckpt.model_checkpoint_path)
graph = tf.get_default_graph()
# 載入模型中的操作節點
inputs_vocab = graph.get_operation_by_name('inputs_vocab').outputs[0]
feature_data_list = graph.get_operation_by_name('inputs_feature_list').outputs[0]
sequence_length = graph.get_operation_by_name('sequence_length').outputs[0]
max_length = graph.get_operation_by_name('max_length').outputs[0]
# 準備測試資料(略)
# in_data = ...
# fea_data_list = ...
# length = ...
# max_len = ...
# feed 資料
feed_dict = {inputs_vocab.name: in_data,
feature_data_list.name: fea_data_list,
sequence_length.name: length,
max_length.name: max_len}
# 計算
viterbi_sequence = sess.run('viterbi_sequence:0', feed_dict)
intent_prediction = sess.run('intent_prediction:0', feed_dict)
3. 讀取frozen graph模型
# 讀取圖檔案
with tf.gfile.FastGFile('./model/frozen_graph.pb', 'rb') as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
# We load the graph_def in the default graph
with tf.Graph().as_default() as graph:
tf.import_graph_def(
graph_def,
input_map=None,
return_elements=None,
name="",
op_dict=None,
producer_op_list=None
)
with tf.Session() as sess:
# 根據名稱返回tensor資料
inputs_vocab = graph.get_tensor_by_name('inputs_vocab:0')
feature_data_list = graph.get_tensor_by_name('inputs_feature_list:0')
sequence_length = graph.get_tensor_by_name('sequence_length:0')
max_length = graph.get_tensor_by_name('max_length:0')
# 準備測試資料(略)
# in_data = ...
# fea_data_list = ...
# length = ...
# max_len = ...
# feed 資料
feed_dict = {inputs_vocab.name: in_data,
feature_data_list.name: fea_data_list,
sequence_length.name: length,
max_length.name: max_len}
# 計算結果
viterbi_sequence = graph.get_tensor_by_name('viterbi_sequence:0')
intent_prediction = graph.get_tensor_by_name('intent_prediction:0')
viterbi_sequence = sess.run(viterbi_sequence, feed_dict)
intent_prediction = sess.run(intent_prediction, feed_dict)
注意,這裡如果不使用上下文管理器Graph().as_default()
,在進行預測的時候可能會報"The Session graph is empty. Add operations to the graph before calling run()…"的錯誤。