How TensorFlow's Estimator and model_fn communicate
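When writing a custom Estimator, it is essential to understand how the Estimator relates to model_fn and to its other arguments. In short, the Estimator takes the arguments it has been given and feeds them into model_fn, and model_fn is the component that actually consumes the data.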
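The direction of this data flow can be pictured with a toy, pure-Python stand-in. `MiniEstimator` and `toy_model_fn` are hypothetical names for illustration only, and plain dicts stand in for `tf.estimator.EstimatorSpec`; the real Estimator also builds graphs, runs sessions and writes checkpoints. What the sketch tries to show is only the division of labour: the estimator calls `input_fn`, then hands features, labels, mode and its stored `params` to `model_fn`, which decides everything about the model.

```python
from typing import Any, Callable, Dict, Optional

# These strings are the values of tf.estimator.ModeKeys.TRAIN / EVAL / PREDICT.
TRAIN, EVAL, PREDICT = "train", "eval", "infer"


class MiniEstimator:
    """Hypothetical toy estimator: it knows nothing about the model itself."""

    def __init__(self, model_fn: Callable[..., Dict[str, Any]],
                 params: Optional[dict] = None):
        self._model_fn = model_fn
        self._params = dict(params or {})  # stored and forwarded, never inspected

    def _call(self, input_fn: Callable[[], Any], mode: str) -> Dict[str, Any]:
        result = input_fn()
        if mode == PREDICT:
            # Labels are not needed for prediction, so model_fn gets labels=None.
            features = result[0] if isinstance(result, tuple) else result
            labels = None
        else:
            features, labels = result
        return self._model_fn(features, labels, mode, self._params)

    def train(self, input_fn):
        return self._call(input_fn, TRAIN)

    def evaluate(self, input_fn):
        return self._call(input_fn, EVAL)

    def predict(self, input_fn):
        return self._call(input_fn, PREDICT)


def toy_model_fn(features, labels, mode, params):
    """Hypothetical model_fn: predicts scale * x, where scale comes from params."""
    preds = [params["scale"] * x for x in features["x"]]
    if mode == PREDICT:
        return {"predictions": preds}
    # Mean squared error serves as the "loss" for TRAIN and EVAL.
    loss = sum((p - y) ** 2 for p, y in zip(preds, labels)) / len(preds)
    return {"predictions": preds, "loss": loss}
```

For example, `MiniEstimator(toy_model_fn, params={"scale": 2.0}).train(lambda: ({"x": [1.0, 2.0]}, [2.0, 5.0]))` returns predictions `[2.0, 4.0]` and a loss of `0.5`. Swapping in a different `model_fn` changes the model without touching `MiniEstimator` at all, which is the point made in the next paragraph.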
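Compared with the estimators in scikit-learn and Spark, TensorFlow's Estimator is more abstract: the model itself, built according to the hyperparameters, is passed in as an argument, so the Estimator's structure and definition barely change when the model changes. In Spark and scikit-learn, by contrast, each algorithm usually comes with its own estimator and its own constructor. Because TensorFlow is more heavily parameterized and gives more freedom, its parameter handling differs from that of the other two.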
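In short, to use the data passed to it, the Estimator must know what that data looks like. Java relies on static type checking and Python on duck typing; alternatively, metadata can describe the incoming data, the two sides can follow an implicit convention, or they can follow an explicit protocol. Nothing enforces a contract between Estimator and model_fn: they rely on convention, and that convention is spelled out in the English TensorFlow documentation quoted below.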
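Much of this convention is about argument names: as the docstring below explains, `config` and `params` are only handed to `model_fn` if it declares parameters with those names. A minimal stdlib sketch of that duck-typed dispatch follows. `call_model_fn` is a hypothetical helper, not TensorFlow's internal code (which validates more); it only assumes the behaviour described in the quoted docstring, including the `ValueError` raised when `params` is given but `model_fn` cannot accept it.

```python
import inspect


def call_model_fn(model_fn, features, labels, mode, params, config):
    """Hypothetical helper: forward optional arguments only if model_fn declares them."""
    accepted = inspect.signature(model_fn).parameters
    kwargs = {}
    if "labels" in accepted:
        kwargs["labels"] = labels
    elif labels is not None:
        raise ValueError("model_fn does not take labels, but input_fn returned labels.")
    if "mode" in accepted:
        kwargs["mode"] = mode
    if "params" in accepted:
        kwargs["params"] = params
    elif params:
        # Compare the documented ValueError: model_fn parameters don't match params.
        raise ValueError("params were passed, but model_fn has no 'params' argument.")
    if "config" in accepted:
        kwargs["config"] = config
    # features is always passed first; everything else depends on the signature.
    return model_fn(features, **kwargs)
```

With this kind of dispatch, a minimal `model_fn(features, labels)` and a full `model_fn(features, labels, mode, params, config)` can both be plugged into the same estimator, which is exactly why no stronger type contract is needed.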
Depending on the value of `mode`, different arguments of `tf.estimator.EstimatorSpec` are required. Namely:
* For `mode == ModeKeys.TRAIN`: required fields are `loss` and `train_op`.
* For `mode == ModeKeys.EVAL`: required field is `loss`.
* For `mode == ModeKeys.PREDICT`: required field is `predictions`.
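Written as data, this contract fits in a small table. The checker below is a hypothetical stdlib sketch, not part of TensorFlow; it treats the spec as a plain dict and relies only on the fact that `tf.estimator.ModeKeys.TRAIN`, `EVAL` and `PREDICT` have the string values `"train"`, `"eval"` and `"infer"`.

```python
# Required EstimatorSpec fields per mode, keyed by the ModeKeys string values.
REQUIRED_FIELDS = {
    "train": {"loss", "train_op"},
    "eval": {"loss"},
    "infer": {"predictions"},
}


def missing_fields(mode, spec):
    """Return the required fields that spec (a dict) lacks or leaves as None."""
    if mode not in REQUIRED_FIELDS:
        raise ValueError(f"unknown mode: {mode!r}")
    present = {name for name, value in spec.items() if value is not None}
    return REQUIRED_FIELDS[mode] - present
```

For instance, `missing_fields("train", {"loss": 0.3})` returns `{"train_op"}`: a `model_fn` that only computes a loss is enough for evaluation but not for training.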
class Estimator(object):
  """Estimator class to train and evaluate TensorFlow models.

  The `Estimator` object wraps a model which is specified by a `model_fn`,
  which, given inputs and a number of other parameters, returns the ops
  necessary to perform training, evaluation, or predictions.

  All outputs (checkpoints, event files, etc.) are written to `model_dir`, or a
  subdirectory thereof. If `model_dir` is not set, a temporary directory is
  used.

  The `config` argument can be passed a `tf.estimator.RunConfig` object
  containing information about the execution environment. It is passed on to
  the `model_fn`, if the `model_fn` has a parameter named "config" (and input
  functions in the same manner). If the `config` parameter is not passed, it is
  instantiated by the `Estimator`. Not passing config means that defaults useful
  for local execution are used. `Estimator` makes config available to the model
  (for instance, to allow specialization based on the number of workers
  available), and also uses some of its fields to control internals, especially
  regarding checkpointing.

  The `params` argument contains hyperparameters. It is passed to the
  `model_fn`, if the `model_fn` has a parameter named "params", and to the input
  functions in the same manner. `Estimator` only passes params along, it does
  not inspect it. The structure of `params` is therefore entirely up to the
  developer.

  None of `Estimator`'s methods can be overridden in subclasses (its
  constructor enforces this). Subclasses should use `model_fn` to configure
  the base class, and may add methods implementing specialized functionality.

  @compatibility(eager)
  Calling methods of `Estimator` will work while eager execution is enabled.
  However, the `model_fn` and `input_fn` are not executed eagerly: `Estimator`
  will switch to graph mode before calling all user-provided functions (incl.
  hooks), so their code has to be compatible with graph mode execution. Note
  that `input_fn` code using `tf.data` generally works in both graph and eager
  modes.
  @end_compatibility
  """
  def __init__(self, model_fn, model_dir=None, config=None, params=None,
               warm_start_from=None):
    """Constructs an `Estimator` instance.

    See [estimators](https://tensorflow.org/guide/estimators) for more
    information.

    To warm-start an `Estimator`:

    ```python
    estimator = tf.estimator.DNNClassifier(
        feature_columns=[categorical_feature_a_emb, categorical_feature_b_emb],
        hidden_units=[1024, 512, 256],
        warm_start_from="/path/to/checkpoint/dir")
    ```

    For more details on warm-start configuration, see
    `tf.estimator.WarmStartSettings`.

    Args:
      model_fn: Model function. Follows the signature:
        * Args:
          * `features`: This is the first item returned from the `input_fn`
            passed to `train`, `evaluate`, and `predict`. This should be a
            single `tf.Tensor` or `dict` of same.
          * `labels`: This is the second item returned from the `input_fn`
            passed to `train`, `evaluate`, and `predict`. This should be a
            single `tf.Tensor` or `dict` of same (for multi-head models).
            If mode is `tf.estimator.ModeKeys.PREDICT`, `labels=None` will
            be passed. If the `model_fn`'s signature does not accept
            `mode`, the `model_fn` must still be able to handle
            `labels=None`.
          * `mode`: Optional. Specifies if this is training, evaluation or
            prediction. See `tf.estimator.ModeKeys`.
          * `params`: Optional `dict` of hyperparameters. Will receive what
            is passed to Estimator in `params` parameter. This allows
            configuring Estimators from hyperparameter tuning.
          * `config`: Optional `estimator.RunConfig` object. Will receive what
            is passed to Estimator as its `config` parameter, or a default
            value. Allows setting up things in your `model_fn` based on
            configuration such as `num_ps_replicas`, or `model_dir`.
        * Returns:
          `tf.estimator.EstimatorSpec`
      model_dir: Directory to save model parameters, graph, etc. This can
        also be used to load checkpoints from the directory into an estimator to
        continue training a previously saved model. If `PathLike` object, the
        path will be resolved. If `None`, the model_dir in `config` will be used
        if set. If both are set, they must be the same. If both are `None`, a
        temporary directory will be used.
      config: `estimator.RunConfig` configuration object.
      params: `dict` of hyper parameters that will be passed into `model_fn`.
        Keys are names of parameters, values are basic python types.
      warm_start_from: Optional string filepath to a checkpoint or SavedModel to
        warm-start from, or a `tf.estimator.WarmStartSettings` object to fully
        configure warm-starting. If the string filepath is provided instead of
        a `tf.estimator.WarmStartSettings`, then all variables are
        warm-started, and it is assumed that vocabularies and `tf.Tensor`
        names are unchanged.

    Raises:
      ValueError: parameters of `model_fn` don't match `params`.
      ValueError: if this is called via a subclass and if that class overrides
        a member of `Estimator`.
    """
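The `model_dir` rules in the docstring above (explicit argument, `config.model_dir`, must agree if both set, temporary directory if neither) can be restated as a small function. `resolve_model_dir` is a hypothetical stdlib sketch of the documented behaviour, not TensorFlow's implementation; `config_model_dir` stands in for `config.model_dir`, and `os.fspath` plays the role of "the path will be resolved" for `PathLike` inputs.

```python
import os
import tempfile


def resolve_model_dir(model_dir=None, config_model_dir=None):
    """Hypothetical sketch of the documented model_dir resolution rules."""
    if model_dir is not None:
        model_dir = os.fspath(model_dir)  # PathLike objects become plain strings
    if model_dir and config_model_dir and model_dir != config_model_dir:
        raise ValueError("model_dir and config.model_dir are both set but differ.")
    chosen = model_dir or config_model_dir
    if chosen is None:
        # Neither is set: fall back to a fresh temporary directory.
        chosen = tempfile.mkdtemp()
    return chosen
```

Because the temporary directory changes on every run, relying on the fallback means checkpoints from one run are not picked up by the next; passing `model_dir` explicitly (or via `RunConfig`) is what makes "continue training a previously saved model" work.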