
TensorFlow Estimator API: checkpoint save behavior during training and checkpoint skip behavior during validation

INFO:tensorflow:Create CheckpointSaverHook.
2018-01-15 16:24:33.513942: I tensorflow/core/platform/cpu_feature_guard.cc:137] Your CPU supports instructions that this TensorFlow binary was not compiled to use: SSE4.1 SSE4.2 AVX AVX2 FMA
2018-01-15 16:24:34.390763: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1030] Found device 0 with properties: name: GeForce GTX 1080 Ti major: 6 minor: 1 memoryClockRate(GHz): 1.582 pciBusID: 0000:89:00.0 totalMemory: 10.91GiB freeMemory: 10.75GiB
2018-01-15 16:24:34.390813: I tensorflow/core/common_runtime/gpu/gpu_device.cc:1120] Creating TensorFlow device (/device:GPU:0) -> (device: 0, name: GeForce GTX 1080 Ti, pci bus id: 0000:89:00.0, compute capability: 6.1)
2018-01-15 16:25:58.010092: I tensorflow/core/kernels/shuffle_dataset_op.cc:110] Filling up shuffle buffer (this may take a while): 499 of 1000
2018-01-15 16:26:07.689469: I tensorflow/core/kernels/shuffle_dataset_op.cc:121] Shuffle buffer filled.
INFO:tensorflow:Saving checkpoints for 1 into /train/mymodels/model.ckpt.
INFO:tensorflow:loss = 22.2663, step = 1
......
DEBUG:tensorflow:Skipping evaluation due to same checkpoint /train/mymodels/model.ckpt-1 for step 100 as for step 50.

The execution flow is as follows:

experiment.train_and_evaluate()

# The evaluation part is implemented with a hook (ValidationMonitor)
if self._min_eval_frequency:
   self._train_monitors += [
       monitors.ValidationMonitor(
           input_fn=self._eval_input_fn,
           eval_steps=self._eval_steps,
           metrics=self._eval_metrics,
           every_n_steps=self._min_eval_frequency,
           name=eval_dir_suffix,
           hooks=self._eval_hooks)
   ]

# The training part ultimately calls estimator._train_model(); the very first training run already saves a checkpoint!!!
self.train(delay_secs=0)
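
For reference, the fields used above (_eval_input_fn, _eval_steps, _min_eval_frequency, _eval_hooks, ...) are simply the arguments that were passed to the Experiment constructor. A minimal sketch of how such an Experiment is typically assembled (my_estimator, train_input_fn and eval_input_fn are hypothetical placeholders, not taken from the run above):

from tensorflow.contrib.learn import Experiment

# my_estimator, train_input_fn and eval_input_fn are hypothetical placeholders
# for an estimator and its two input functions.
experiment = Experiment(
    estimator=my_estimator,
    train_input_fn=train_input_fn,
    eval_input_fn=eval_input_fn,
    train_steps=10000,
    eval_steps=100,
    min_eval_frequency=50)  # the ValidationMonitor above fires at most every 50 steps

experiment.train_and_evaluate()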

Training part

experiment.train(delay_secs=0) -> experiment._estimator.train -> estimator._train_model()

# estimator._train_model() code
# ...
      # 1. Add loss monitoring (via hooks)
      # Check if the user created a loss summary, and add one if they didn't.
      # We assume here that the summary is called 'loss'. If it is not, we will
      # make another one with the name 'loss' to ensure it shows up in the right
      # graph in TensorBoard.
      if not any([x.op.name == 'loss'
                  for x in ops.get_collection(ops.GraphKeys.SUMMARIES)]):
        summary.scalar('loss', estimator_spec.loss)
      ops.add_to_collection(ops.GraphKeys.LOSSES, estimator_spec.loss)
      worker_hooks.extend(hooks)
      worker_hooks.extend([
          training.NanTensorHook(estimator_spec.loss),
          training.LoggingTensorHook(
              {
                  'loss': estimator_spec.loss,
                  'step': global_step_tensor
              },
              every_n_iter=100)
      ])
      worker_hooks.extend(estimator_spec.training_hooks)

      # 2. Create a Saver if none was provided
      if not (estimator_spec.scaffold.saver or
              ops.get_collection(ops.GraphKeys.SAVERS)):
        ops.add_to_collection(
            ops.GraphKeys.SAVERS,
            training.Saver(
                sharded=True,
                max_to_keep=self._config.keep_checkpoint_max,
                keep_checkpoint_every_n_hours=(
                    self._config.keep_checkpoint_every_n_hours),
                defer_build=True,
                save_relative_paths=True))

      chief_hooks = []
      all_hooks = worker_hooks + list(estimator_spec.training_chief_hooks)
      saver_hooks = [
          h for h in all_hooks if isinstance(h, training.CheckpointSaverHook)]
      if (self._config.save_checkpoints_secs or
          self._config.save_checkpoints_steps):
        if not saver_hooks:
          # 3. CheckpointSaverHook: this is the key point for saving checkpoints
          chief_hooks = [
              training.CheckpointSaverHook(
                  self._model_dir,
                  save_secs=self._config.save_checkpoints_secs,
                  save_steps=self._config.save_checkpoints_steps,
                  scaffold=estimator_spec.scaffold)
          ]
          saver_hooks = [chief_hooks[0]]
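
Note that save_secs / save_steps come straight from self._config, i.e. the RunConfig the Estimator was constructed with, and keep_checkpoint_max feeds the default Saver created above. A minimal sketch of steering the save cadence from user code (toy model, /tmp path and step counts are made up for illustration):

import numpy as np
import tensorflow as tf

def model_fn(features, labels, mode):
    # Toy linear model: just enough to produce a loss and a train_op.
    preds = tf.layers.dense(features["x"], 1)
    loss = tf.losses.mean_squared_error(labels, preds)
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(
        loss, global_step=tf.train.get_global_step())
    return tf.estimator.EstimatorSpec(mode, predictions=preds, loss=loss, train_op=train_op)

# save_checkpoints_steps ends up as CheckpointSaverHook(save_steps=...),
# keep_checkpoint_max ends up in the Saver created by _train_model().
config = tf.estimator.RunConfig(
    model_dir="/tmp/ckpt_demo",
    save_checkpoints_steps=50,
    keep_checkpoint_max=5)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=config)

x = np.random.rand(1024, 1).astype(np.float32)
train_input_fn = tf.estimator.inputs.numpy_input_fn(
    {"x": x}, 3.0 * x + 1.0, batch_size=32, num_epochs=None, shuffle=True)
estimator.train(input_fn=train_input_fn, steps=200)
# Checkpoints are written at step 1 (first trigger), then about every 50 steps,
# plus one final save when training ends.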

CheckpointSaverHook

class CheckpointSaverHook(session_run_hook.SessionRunHook):
  """Saves checkpoints every N steps or seconds."""

  def __init__(self,
               checkpoint_dir,
               save_secs=None,
               save_steps=None,
               saver=None,
               checkpoint_basename="model.ckpt",
               scaffold=None,
               listeners=None):
    """Initializes a `CheckpointSaverHook`.

    Args:
      checkpoint_dir: `str`, base directory for the checkpoint files.
      save_secs: `int`, save every N secs.
      save_steps: `int`, save every N steps.
      saver: `Saver` object, used for saving.
      checkpoint_basename: `str`, base name for the checkpoint files.
      scaffold: `Scaffold`, use to get saver object.
      listeners: List of `CheckpointSaverListener` subclass instances.
        Used for callbacks that run immediately before or after this hook saves
        the checkpoint.

    Raises:
      ValueError: One of `save_steps` or `save_secs` should be set.
      ValueError: At most one of saver or scaffold should be set.
    """
    logging.info("Create CheckpointSaverHook.")
    if saver is not None and scaffold is not None:
      raise ValueError("You cannot provide both saver and scaffold.")
    self._saver = saver
    self._checkpoint_dir = checkpoint_dir
    self._save_path = os.path.join(checkpoint_dir, checkpoint_basename)
    self._scaffold = scaffold
    self._timer = SecondOrStepTimer(every_secs=save_secs,
                                    every_steps=save_steps)
    self._listeners = listeners or []

  def begin(self):
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
    self._global_step_tensor = training_util._get_or_create_global_step_read()  # pylint: disable=protected-access
    if self._global_step_tensor is None:
      raise RuntimeError(
          "Global step should be created to use CheckpointSaverHook.")
    for l in self._listeners:
      l.begin()

  def before_run(self, run_context):  # pylint: disable=unused-argument
    if self._timer.last_triggered_step() is None:
      # We do write graph and saver_def at the first call of before_run.
      # We cannot do this in begin, since we let other hooks to change graph and
      # add variables in begin. Graph is finalized after all begin calls.
      training_util.write_graph(
          ops.get_default_graph().as_graph_def(add_shapes=True),
          self._checkpoint_dir,
          "graph.pbtxt")
      saver_def = self._get_saver().saver_def if self._get_saver() else None
      graph = ops.get_default_graph()
      meta_graph_def = meta_graph.create_meta_graph_def(
          graph_def=graph.as_graph_def(add_shapes=True),
          saver_def=saver_def)
      self._summary_writer.add_graph(graph)
      self._summary_writer.add_meta_graph(meta_graph_def)

    return SessionRunArgs(self._global_step_tensor)

  def after_run(self, run_context, run_values):
    stale_global_step = run_values.results
    # This function is key!!! It returns True both on the very first run and when it is time to write a checkpoint.
    if self._timer.should_trigger_for_step(stale_global_step+1):
      # get the real value after train op.
      global_step = run_context.session.run(self._global_step_tensor)
      if self._timer.should_trigger_for_step(global_step):
        self._timer.update_last_triggered_step(global_step)
        self._save(run_context.session, global_step)

  def end(self, session):
    last_step = session.run(self._global_step_tensor)
    if last_step != self._timer.last_triggered_step():
      self._save(session, last_step)
    for l in self._listeners:
      l.end(session, last_step)

  def _save(self, session, step):
    """Saves the latest checkpoint."""
    logging.info("Saving checkpoints for %d into %s.", step, self._save_path)

    for l in self._listeners:
      l.before_save(session, step)

    self._get_saver().save(session, self._save_path, global_step=step)
    self._summary_writer.add_session_log(
        SessionLog(
            status=SessionLog.CHECKPOINT, checkpoint_path=self._save_path),
        step)

    for l in self._listeners:
      l.after_save(session, step)

  def _get_saver(self):
    if self._saver is not None:
      return self._saver
    elif self._scaffold is not None:
      return self._scaffold.saver

    # Get saver from the SAVERS collection if present.
    collection_key = ops.GraphKeys.SAVERS
    savers = ops.get_collection(collection_key)
    if not savers:
      raise RuntimeError(
          "No items in collection {}. Please add a saver to the collection "
          "or provide a saver or scaffold.".format(collection_key))
    elif len(savers) > 1:
      raise RuntimeError(
          "More than one item in collection {}. "
          "Please indicate which one to use by passing it to the constructor.".
          format(collection_key))

    self._saver = savers[0]
    return savers[0]
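
The hook is not tied to Estimator; the same "save at the first step, then every N" behaviour can be observed by driving it directly with a MonitoredTrainingSession. A minimal sketch (directory and step counts are made up for illustration):

import tensorflow as tf

global_step = tf.train.get_or_create_global_step()
train_op = tf.assign_add(global_step, 1)   # stand-in for a real training op

saver_hook = tf.train.CheckpointSaverHook(
    checkpoint_dir="/tmp/ckpt_hook_demo",
    save_steps=50,
    saver=tf.train.Saver())

with tf.train.MonitoredTrainingSession(hooks=[saver_hook]) as sess:
    for _ in range(120):
        sess.run(train_op)
# after_run() saves at step 1 (first trigger of its SecondOrStepTimer),
# then again around steps 51 and 101; end() writes one last checkpoint at step 120.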

SecondOrStepTimer.should_trigger_for_step

class SecondOrStepTimer(_HookTimer):
  """Timer that triggers at most once every N seconds or once every N steps.
  """

  def __init__(self, every_secs=None, every_steps=None):
    self.reset()
    self._every_secs = every_secs
    self._every_steps = every_steps

    if self._every_secs is None and self._every_steps is None:
      raise ValueError("Either every_secs or every_steps should be provided.")
    if (self._every_secs is not None) and (self._every_steps is not None):
      raise ValueError("Can not provide both every_secs and every_steps.")

    super(SecondOrStepTimer, self).__init__()

  def reset(self):
    self._last_triggered_step = None
    self._last_triggered_time = None

  def should_trigger_for_step(self, step):
    """Return true if the timer should trigger for the specified step.

    Args:
      step: Training step to trigger on.

    Returns:
      True if the difference between the current time and the time of the last
      trigger exceeds `every_secs`, or if the difference between the current
      step and the last triggered step exceeds `every_steps`. False otherwise.
    """
    # If this is the first call, always trigger
    if self._last_triggered_step is None:
      return True

    if self._last_triggered_step == step:
      return False

    if self._every_secs is not None:
      if time.time() >= self._last_triggered_time + self._every_secs:
        return True

    if self._every_steps is not None:
      if step >= self._last_triggered_step + self._every_steps:
        return True

    return False

  def update_last_triggered_step(self, step):
    current_time = time.time()
    if self._last_triggered_time is None:
      elapsed_secs = None
      elapsed_steps = None
    else:
      elapsed_secs = current_time - self._last_triggered_time
      elapsed_steps = step - self._last_triggered_step

    self._last_triggered_time = current_time
    self._last_triggered_step = step
    return (elapsed_secs, elapsed_steps)

  def last_triggered_step(self):
    return self._last_triggered_step
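
This behaviour is easy to confirm in isolation by instantiating the timer directly (imported here from the internal module the excerpt above comes from):

from tensorflow.python.training.basic_session_run_hooks import SecondOrStepTimer

timer = SecondOrStepTimer(every_steps=50)

print(timer.should_trigger_for_step(1))    # True:  _last_triggered_step is still None (first call)
timer.update_last_triggered_step(1)
print(timer.should_trigger_for_step(30))   # False: 30 < 1 + 50
print(timer.should_trigger_for_step(51))   # True:  51 >= 1 + 50
timer.update_last_triggered_step(51)
print(timer.should_trigger_for_step(51))   # False: same step as the last trigger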

Evaluation part

# experiment.train_and_evaluate()

self._train_monitors += [
    monitors.ValidationMonitor(
        input_fn=self._eval_input_fn,
        eval_steps=self._eval_steps,
        metrics=self._eval_metrics,
        every_n_steps=self._min_eval_frequency,
        name=eval_dir_suffix,
        hooks=self._eval_hooks)
]

class ValidationMonitor(EveryN):
  """Runs evaluation of a given estimator, at most every N steps.

  Note that the evaluation is done based on the saved checkpoint, which will
  usually be older than the current step.

  Can do early stopping on validation metrics if `early_stopping_rounds` is
  provided.
  """

  # ...... (omitted)

  def every_n_step_end(self, step, outputs):
    super(ValidationMonitor, self).every_n_step_end(step, outputs)

    # Check that we are not running evaluation on the same checkpoint.
    latest_path = saver_lib.latest_checkpoint(self._estimator.model_dir)
    if latest_path is None:
      logging.debug("Skipping evaluation since model has not been saved yet "
                    "at step %d.", step)
      return False
    if latest_path is not None and latest_path == self._latest_path:
      # Prevent re-evaluating the same checkpoint!!!!
      logging.debug("Skipping evaluation due to same checkpoint %s for step %d "
                    "as for step %d.", latest_path, step,
                    self._latest_path_step)
      return False
    self._latest_path = latest_path
    self._latest_path_step = step
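
Putting the two halves together explains the DEBUG line at the top of this post: the ValidationMonitor fires every min_eval_frequency steps, but it only evaluates when saver_lib.latest_checkpoint() returns a path different from the one it evaluated last time. Since the CheckpointSaverHook had only written model.ckpt-1 so far, the evaluation at step 50 used that checkpoint and the one at step 100 was skipped. A simplified standalone restatement of the skip condition (the helper name maybe_evaluate and the messages are made up for illustration):

import tensorflow as tf

_last_evaluated_path = None

def maybe_evaluate(model_dir, step):
    """Mimics the skip condition in ValidationMonitor.every_n_step_end()."""
    global _last_evaluated_path
    latest_path = tf.train.latest_checkpoint(model_dir)
    if latest_path is None:
        print("step %d: no checkpoint saved yet, skipping evaluation" % step)
        return False
    if latest_path == _last_evaluated_path:
        print("step %d: same checkpoint %s, skipping evaluation" % (step, latest_path))
        return False
    _last_evaluated_path = latest_path
    print("step %d: evaluating %s" % (step, latest_path))
    return True

In practice the skips disappear once save_checkpoints_steps (or save_checkpoints_secs) in RunConfig is at least as frequent as min_eval_frequency, so that every ValidationMonitor trigger sees a fresh checkpoint.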