1. 程式人生 > 實用技巧 >bert文字分類模型儲存為savedmodel方式

bert文字分類模型儲存為savedmodel方式

預設bert是ckpt,在進行後期優化和部署時,savedmodel方式更加友好寫。

train完成後,呼叫如下函式:

def save_savedmodel(estimator, serving_dir, seq_length, is_tpu_estimator):
    feature_map = {
        "input_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='input_ids'),
        "input_mask": tf.placeholder(tf.int32, shape=[None, seq_length], name='
input_mask'), "segment_ids": tf.placeholder(tf.int32, shape=[None, seq_length], name='segment_ids'), "label_ids": tf.placeholder(tf.int32, shape=[None], name='label_ids'), } serving_input_receiver_fn = tf.estimator.export.build_raw_serving_input_receiver_fn(feature_map) estimator.export_savedmodel(serving_dir, serving_input_receiver_fn, strip_default_attrs
=True) print("儲存savedmodel")

estimator:estimator = Estimator(model_fn=model_fn,params={},config=run_config)

serving_dir:儲存目錄

seq_length:樣本長度

is_tpu_estimator: tpu標誌位