TensorFlow API之tf.estimator.Estimator
阿新 • • 發佈:2018-12-10
tf.estimator.Estimator
Estimator class訓練和測試TF模型。Estimator
物件封裝好通過model_fn
指定的模型,給定輸入和其它超引數,返回ops執行training, evaluation or prediction. 所有的輸出(包含checkpoints, event files, etc.)被寫入model_dir
。
屬性
- config
傳入
model_fn
,如果model_fn
有引數named “config” - model_dir
- model_fn
The model_fn with following signature:
def model_fn(features, labels, mode, config)
- params
方法
__init__
__init__(
model_fn,
model_dir=None,
config=None,
params=None # 將要傳入model_fn的超引數字典
)
evaluate
對訓練模型評價
evaluate( input_fn, # 輸入函式,返回元組features和labels steps=None, hooks=None, # List of SessionRunHook subclass instances checkpoint_path=None, # if none, 用model_dir中latest checkpoint name=None )
export_savemodel
匯出inference graph作為一個SavedModel
export_savedmodel(
export_dir_base, # 目錄
serving_input_receiver_fn, # 返回ServingInputReceiver的函式
assets_extra=None,
as_text=False,
checkpoint_path=None
)
-
get_variable_names
get_variable_names() 返回模型中所有變數名字的列表
-
get_variable_value(name) 根據變數name返回value
-
latest_checkpoint() 在
model_dir
中找到最近儲存的checkpoint -
predict 根據給定的features產生預測
predict(
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None
)
- train
給定訓練資料後訓練model
train(
input_fn,
hooks=None,
steps=None,
max_steps=None,
saving_listeners=None
)