1. 程式人生 > >TensorFlow模型格式簡介

TensorFlow模型格式簡介

簡介

TensorFlow的模型格式有很多種,針對不同場景可以使用不同的格式,只要符合規範的模型都可以輕易部署到線上服務或移動裝置上,這裡簡單列舉一下。

  • Checkpoint: 用於儲存模型的權重,主要用於模型訓練過程中引數的備份和模型訓練熱啟動。
  • GraphDef:用於儲存模型的Graph,不包含模型權重,加上checkpoint後就有模型上線的全部資訊。
  • ExportModel:使用exportor介面匯出的模型檔案,包含模型Graph和權重可直接用於上線,但官方已經標記為deprecated推薦使用SavedModel。
  • SavedModel:使用saved_model介面匯出的模型檔案,包含模型Graph和許可權可直接用於上線,TensorFlow和Keras模型推薦使用這種模型格式。
  • FrozenGraph:使用freeze_graph.py對checkpoint和GraphDef進行整合和優化,可以直接部署到Android、iOS等移動裝置上。
  • TFLite:基於flatbuf對模型進行優化,可以直接部署到Android、iOS等移動裝置上,使用介面和FrozenGraph有些差異。

模型格式

目前建議TensorFlow和Keras模型都匯出成SavedModel格式,這樣就可以直接使用通用的TensorFlow Serving服務,模型匯出即可上線不需要改任何程式碼。不同的模型匯出時只要指定輸入和輸出的signature即可,其中字串的key可以任意命名只會在客戶端請求時用到,可以參考下面的程式碼示例。

注意,目前使用tf.py_func()的模型匯出後不能直接上線,模型的所有結構建議都用op實現。

TensorFlow模型匯出

import os
import tensorflow as tf
from tensorflow.python.saved_model import builder as saved_model_builder
from tensorflow.python.saved_model import (
    signature_constants, signature_def_utils, tag_constants, utils)
from tensorflow.python.util import compat

model_path = "model"
model_version = 1
model_signature = signature_def_utils.build_signature_def(
    inputs={
        "keys": utils.build_tensor_info(keys_placeholder),
        "features": utils.build_tensor_info(inference_features)
    },
    outputs={
        "keys": utils.build_tensor_info(keys_identity),
        "prediction": utils.build_tensor_info(inference_op),
        "softmax": utils.build_tensor_info(inference_softmax),
    },
    method_name=signature_constants.PREDICT_METHOD_NAME)
export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op')
 
builder = saved_model_builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
    sess, [tag_constants.SERVING],
    clear_devices=True,
    signature_def_map={
        signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
        model_signature,
    },
    legacy_init_op=legacy_init_op)
 
builder.save() 

Keras模型匯出

import os
import tensorflow as tf
from tensorflow.python.util import compat
 
def export_savedmodel(model):
  model_path = "model"
  model_version = 1
  model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
      inputs={'input': model.input}, outputs={'output': model.output})
  export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
 
  builder = tf.saved_model.builder.SavedModelBuilder(export_path)
  builder.add_meta_graph_and_variables(
      sess=K.get_session(),
      tags=[tf.saved_model.tag_constants.SERVING],
      clear_devices=True,
      signature_def_map={
          tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
          model_signature
      })
  builder.save()

SavedModel模型結構

使用TensorFlow的API匯出SavedModel模型後,可以檢查模型的目錄結構如下,然後就可以直接使用開源工具來載入服務了。

模型上線

部署線上服務

部署離線裝置

待更新