TensorFlow模型格式簡介
阿新 • • 發佈:2018-12-22
簡介
TensorFlow的模型格式有很多種,針對不同場景可以使用不同的格式,只要符合規範的模型都可以輕易部署到線上服務或移動裝置上,這裡簡單列舉一下。
- Checkpoint: 用於儲存模型的權重,主要用於模型訓練過程中引數的備份和模型訓練熱啟動。
- GraphDef:用於儲存模型的Graph,不包含模型權重,加上checkpoint後就有模型上線的全部資訊。
- ExportModel:使用exportor介面匯出的模型檔案,包含模型Graph和權重可直接用於上線,但官方已經標記為deprecated推薦使用SavedModel。
- SavedModel:使用saved_model介面匯出的模型檔案,包含模型Graph和許可權可直接用於上線,TensorFlow和Keras模型推薦使用這種模型格式。
- FrozenGraph:使用freeze_graph.py對checkpoint和GraphDef進行整合和優化,可以直接部署到Android、iOS等移動裝置上。
- TFLite:基於flatbuf對模型進行優化,可以直接部署到Android、iOS等移動裝置上,使用介面和FrozenGraph有些差異。
模型格式
目前建議TensorFlow和Keras模型都匯出成SavedModel格式,這樣就可以直接使用通用的TensorFlow Serving服務,模型匯出即可上線不需要改任何程式碼。不同的模型匯出時只要指定輸入和輸出的signature即可,其中字串的key可以任意命名只會在客戶端請求時用到,可以參考下面的程式碼示例。
注意,目前使用tf.py_func()的模型匯出後不能直接上線,模型的所有結構建議都用op實現。
TensorFlow模型匯出
import os import tensorflow as tf from tensorflow.python.saved_model import builder as saved_model_builder from tensorflow.python.saved_model import ( signature_constants, signature_def_utils, tag_constants, utils) from tensorflow.python.util import compat model_path = "model" model_version = 1 model_signature = signature_def_utils.build_signature_def( inputs={ "keys": utils.build_tensor_info(keys_placeholder), "features": utils.build_tensor_info(inference_features) }, outputs={ "keys": utils.build_tensor_info(keys_identity), "prediction": utils.build_tensor_info(inference_op), "softmax": utils.build_tensor_info(inference_softmax), }, method_name=signature_constants.PREDICT_METHOD_NAME) export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version))) legacy_init_op = tf.group(tf.tables_initializer(), name='legacy_init_op') builder = saved_model_builder.SavedModelBuilder(export_path) builder.add_meta_graph_and_variables( sess, [tag_constants.SERVING], clear_devices=True, signature_def_map={ signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: model_signature, }, legacy_init_op=legacy_init_op) builder.save()
Keras模型匯出
import os
import tensorflow as tf
from tensorflow.python.util import compat
def export_savedmodel(model):
model_path = "model"
model_version = 1
model_signature = tf.saved_model.signature_def_utils.predict_signature_def(
inputs={'input': model.input}, outputs={'output': model.output})
export_path = os.path.join(compat.as_bytes(model_path), compat.as_bytes(str(model_version)))
builder = tf.saved_model.builder.SavedModelBuilder(export_path)
builder.add_meta_graph_and_variables(
sess=K.get_session(),
tags=[tf.saved_model.tag_constants.SERVING],
clear_devices=True,
signature_def_map={
tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
model_signature
})
builder.save()
SavedModel模型結構
使用TensorFlow的API匯出SavedModel模型後,可以檢查模型的目錄結構如下,然後就可以直接使用開源工具來載入服務了。
模型上線
部署線上服務
部署離線裝置
待更新