1. 程式人生 > >tensorflow 讀取模型並進行預測

tensorflow 讀取模型並進行預測

tensorflow 讀取兩種格式的模型並進行預測

文章目錄

1. 模型儲存

1.1 checkpoint 模型

如圖所示,
.meta – 儲存圖結構,即神經網路的網路結構
.data

– 儲存資料檔案,即網路的權值,偏置,操作等等
.index – 是一個不可變得字串表,每一個鍵都是張量的名稱,它的值是一個序列化的BundleEntryProto。 每個BundleEntryProto描述張量的元資料:“資料”檔案中的哪個檔案包含張量的內容,該檔案的偏移量,校驗和,一些輔助資料等等。
checkpoint – 文字檔案,裡面記錄了儲存的最新的checkpoint檔案以及其它checkpoint檔案列表。在inference時,可以通過修改這個檔案,指定使用哪個model.

儲存模型:

saver = tf.train.Saver()
saver.save(sess,
model_path)

其中model_path是模型儲存路徑。

1.2 frozen_graph模型

在工程中,我們往往需要將模型和權重固化,便於釋出和預測。
使用tensorFlow官方提供的freeze_graph.py工具來儲存相應模型。(程式碼中把freeze_graph.py檔案放在commom.utils.tf路徑下匯入)

freeze_graph.py先載入模型檔案,從checkpoint檔案讀取權重資料初始化到模型裡的權重變數,再將權重變數轉換成權重常量,然後再通過指定的輸出節點將沒用於輸出推理的Op節點從圖中剝離掉,再重新儲存到指定的檔案裡(用write_graphdef或Saver)。

from tensorflow.core.protobuf import saver_pb2
from common.utils.tf import freeze_graph
# save model graph
tf.train.write_graph(
    sess.graph.as_graph_def(),
    os.path.join(model_path),
    GRAPH_PB_NAME,
    as_text=False)
# generate frozen graph
freeze_graph.freeze_graph(
    input_graph=os.path.join(model_path, GRAPH_PB_NAME),
    input_saver=False,
    input_binary=True,
    input_checkpoint=os.path.join(model_path, CHECKPOINT_PREFIX),
    output_node_names="viterbi_sequence,intent_prediction,intent_probs",
    restore_op_name=None,
    filename_tensor_name=None,
    output_graph=os.path.join(model_path, FROZEN_GRAPH_PB_NAME),
    clear_devices=False,
    initializer_nodes="",
    variable_names_whitelist="",
    variable_names_blacklist="",
    input_meta_graph=None,
    input_saved_model_dir=None,
    saved_model_tags=tf.saved_model.tag_constants.SERVING,
    checkpoint_version=saver_pb2.SaverDef.V2)

其中model_path是模型儲存路徑,GRAPH_PB_NAME定義了圖模型的名字。

freeze_graph主要引數(參考[4]部落格中的引數說明):

  • input_graph : 模型檔案,可以是二進位制的pb檔案,或文字的meta檔案,用input_binary來指定區分。
  • input_checkpoint : 檢查點資料檔案。
  • output_node_names : 輸出節點的名字,有多個時用逗號分開。
  • output_graph : 儲存整合後的輸出模型。

2. 讀取ckpt模型

在我的模型中,要求的輸入有四個,分別是inputs_vocab,inputs_feature_list,sequence_length,max_length
計算得到的輸出有兩個viterbi_sequenceintent_prediction

ckpt = tf.train.get_checkpoint_state(arg.model_path + '/')
saver = tf.train.import_meta_graph(ckpt.model_checkpoint_path + '.meta')    
with tf.Session() as sess:
    saver.restore(sess, ckpt.model_checkpoint_path)
    graph = tf.get_default_graph()   
    # 載入模型中的操作節點	
    inputs_vocab = graph.get_operation_by_name('inputs_vocab').outputs[0]
    feature_data_list = graph.get_operation_by_name('inputs_feature_list').outputs[0]
    sequence_length = graph.get_operation_by_name('sequence_length').outputs[0]
    max_length = graph.get_operation_by_name('max_length').outputs[0]
    # 準備測試資料(略)
    # in_data = ...
    # fea_data_list = ...
    # length = ...
    # max_len = ...
    # feed 資料
    feed_dict = {inputs_vocab.name: in_data,
                 feature_data_list.name: fea_data_list,
                 sequence_length.name: length,
                 max_length.name: max_len}  
    # 計算
    viterbi_sequence = sess.run('viterbi_sequence:0', feed_dict)
    intent_prediction = sess.run('intent_prediction:0', feed_dict)      

3. 讀取frozen graph模型

# 讀取圖檔案
with tf.gfile.FastGFile('./model/frozen_graph.pb', 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
    # We load the graph_def in the default graph
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(
            graph_def,
            input_map=None,
            return_elements=None,
            name="",
            op_dict=None,
            producer_op_list=None
        )
        with tf.Session() as sess:
            # 根據名稱返回tensor資料
            inputs_vocab = graph.get_tensor_by_name('inputs_vocab:0')
            feature_data_list = graph.get_tensor_by_name('inputs_feature_list:0')
            sequence_length = graph.get_tensor_by_name('sequence_length:0')
            max_length = graph.get_tensor_by_name('max_length:0')
            # 準備測試資料(略)
            # in_data = ...
            # fea_data_list = ...
            # length = ...
            # max_len = ...
            # feed 資料
            feed_dict = {inputs_vocab.name: in_data,
                         feature_data_list.name: fea_data_list,
                         sequence_length.name: length,
                         max_length.name: max_len}
            # 計算結果
            viterbi_sequence = graph.get_tensor_by_name('viterbi_sequence:0')
            intent_prediction = graph.get_tensor_by_name('intent_prediction:0')
            viterbi_sequence = sess.run(viterbi_sequence, feed_dict)
            intent_prediction = sess.run(intent_prediction, feed_dict)

注意,這裡如果不使用上下文管理器Graph().as_default(),在進行預測的時候可能會報"The Session graph is empty. Add operations to the graph before calling run()…"的錯誤。

參考部落格:

  1. tensorflow將訓練好的模型freeze,即將權重固化到圖裡面,並使用該模型進行預測
  2. 使用TensorFlow C++ API構建線上預測服務
  3. Tensorflow載入預訓練模型和儲存模型
  4. tensorflow,使用freeze_graph.py將模型檔案和權重資料整合在一起並去除無關的Op