
[Source Code Analysis] TensorFlow Distributed DistributedStrategy: Fundamentals

In the previous article we introduced the basic concept of Strategy. tf.distribute.Strategy is a TensorFlow API for distributed training across multiple GPUs, multiple machines, or TPUs. With this API, you only need to change a small amount of code to take an existing model and training code and run it in single-machine multi-GPU, multi-machine multi-GPU and similar distributed setups. tf.distribute.Strategy is designed with the following goals:

  • Cover user use cases along different dimensions.
  • Be easy to use and support many kinds of users (researchers, ML engineers, etc.).
  • Provide good performance out of the box.
  • Be decoupled from the user's model code, so that switching strategies is easy.
  • Support custom training loops, Estimator, and Keras.
  • Support eager execution.

From a system or developer perspective, Strategy is a mechanism implemented with Python scopes or decorators. It provides a set of named distribution strategies, such as ParameterServerStrategy and CollectiveStrategy, that are used as Python scopes; these scopes capture the model declaration and training logic in user functions and take effect when the user code starts. In the backend, the distributed runtime can rewrite the computation graph and merge in the corresponding semantics for the chosen strategy (parameter server or collective).

The core of our analysis is therefore how data reading, model parameters and distributed computation are folded into Python scopes or decorators. In this chapter we start with the Strategy class hierarchy and data reading.

As usual, I recommend two great resources:

[TensorFlow Internals] (https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, I recommend it to anyone interested in TF's internal implementation mechanisms; you will definitely gain a lot from it.
https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少: not only TensorFlow, this account also covers the cutting edge of many other areas of the industry.

Other articles in this series:

[Translation] TensorFlow distributed papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow distributed papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow distributed environment (1) --- Overall architecture

[Source Code Analysis] TensorFlow distributed environment (2) --- Master static logic

[Source Code Analysis] TensorFlow distributed environment (3) --- Worker static logic

[Source Code Analysis] TensorFlow distributed environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow distributed environment (5) --- Session

[Source Code Analysis] TensorFlow distributed environment (7) --- Worker dynamic logic

[Source Code Analysis] TensorFlow distributed environment (8) --- Communication mechanism

[Translation] Distributed training with TensorFlow

1. StrategyBase

StrategyBase is a policy for distributing state and computation over a list of devices. It is the base class of the v1 and v2 strategy classes.

1.1 Initialization

The main thing StrategyBase's initializer does is set up extended, whose type is StrategyExtendedV2 or StrategyExtendedV1.

class StrategyBase(object):

  def __init__(self, extended):
    
    self._extended = extended

    # Flag that is used to indicate whether distribution strategy is used with
    # Estimator. This is required for backward compatibility of loss scaling
    # when using v1 optimizer with estimator.
    self._scale_loss_for_estimator = False

    if not hasattr(extended, "_retrace_functions_for_each_device"):
        # extended._retrace_functions_for_each_device dictates
      # whether the same function will be retraced when it is called on
      # different devices.
      try:
        extended._retrace_functions_for_each_device = (
            len(extended.worker_devices) > 1)
        distribution_strategy_replica_gauge.get_cell("num_replicas").set(
            self.num_replicas_in_sync)
      except:  
        # Default for the case where extended.worker_devices can't return
        # a sensible value.
        extended._retrace_functions_for_each_device = True

    # Below are the dicts of axis(int) -> tf.function.
    self._mean_reduce_helper_fns = {}
    self._reduce_sum_fns = {}

    # Whether this strategy is designed to work with ClusterCoordinator.
    self._should_use_with_coordinator = False

  @property
  def extended(self):
    """tf.distribute.StrategyExtended with additional methods."""
    return self._extended

1.2 Usage

If you want to use Keras compile/fit, see https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_keras.

You can also pass a subclass of tf.distribute.Strategy to tf.estimator.RunConfig to specify how a tf.estimator.Estimator should distribute its computation; see https://www.tensorflow.org/guide/distributed_training#using_tfdistributestrategy_with_estimator_limited_support.

When creating and running a model, you should first specify a strategy with tf.distribute.Strategy.scope. Specifying a strategy means your code runs in that strategy's cross-replica context, so the strategy is in charge of things such as variable placement.

1.3 CTL

If you are writing a custom training loop, you need to call a few more methods:

  • Use tf.distribute.Strategy.experimental_distribute_dataset to convert a tf.data.Dataset so that it produces per-replica values. If you want to control how the dataset is partitioned across replicas yourself, use tf.distribute.Strategy.distribute_datasets_from_function instead.
  • Use tf.distribute.Strategy.run to run a function on each replica. The function receives per-replica values (for example from a tf.distribute.DistributedDataset object) and returns per-replica results. It runs in the replica context, which means every op runs separately on each replica.
  • Finally, use a method such as tf.distribute.Strategy.reduce to convert the resulting per-replica values back into ordinary tensors.

The code below is a typical CTL use case: a plain dataset and a replica_fn run distributed under a particular tf.distribute.Strategy named my_strategy. Any variables created inside replica_fn are created under my_strategy's policy.

Users can use the reduce API to aggregate the results across replicas and use it as the return value of one iteration over the tf.distribute.DistributedDataset. Users can also use tf.keras.metrics (such as loss and accuracy) to accumulate metrics across steps.

  with my_strategy.scope():
    @tf.function
    def distribute_train_epoch(dataset):
      def replica_fn(input):
        # process input and return result
        return result

      total_result = 0
      for x in dataset:
        per_replica_result = my_strategy.run(replica_fn, args=(x,))
        total_result += my_strategy.reduce(tf.distribute.ReduceOp.SUM,
                                           per_replica_result, axis=None)
      return total_result

    dist_dataset = my_strategy.experimental_distribute_dataset(dataset)
    for _ in range(EPOCHS):
      train_result = distribute_train_epoch(dist_dataset)

1.4 Scope

The scope of a distribution strategy determines how and where variables are created. For MultiWorkerMirroredStrategy, for example, the variables created are MirroredVariable instances, and the strategy replicates them onto each worker. The scope method simply delegates to _extended._scope, which returns a context manager that makes this strategy the current one and distributes variables.

def scope(self):
  """Context manager to make the strategy current and distribute variables.

  Returns:
    A context manager.
  """
  return self._extended._scope(self)

1.4.1 Usage

Typical usage is as follows:

>>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])
>>> # Variable created inside scope:
>>> with strategy.scope():
...   mirrored_variable = tf.Variable(1.)
>>> mirrored_variable
MirroredVariable:{
  0: <tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>,
  1: <tf.Variable 'Variable/replica_1:0' shape=() dtype=float32, numpy=1.0>
}
>>> # Variable created outside scope:
>>> regular_variable = tf.Variable(1.)
>>> regular_variable
<tf.Variable 'Variable:0' shape=() dtype=float32, numpy=1.0>

1.4.2 Behavior

After entering Strategy.scope, the following happens:

  • The strategy is installed in the global context as the "current" strategy. Inside the scope, calling tf.distribute.get_strategy() returns this strategy; outside the scope, it returns the default no-op strategy (see the sketch below).
  • Entering the scope also enters the cross-replica context.
  • Variable creation inside the scope is intercepted by the strategy. Each strategy defines how it wants to affect variable creation. Synchronous strategies such as MirroredStrategy, TPUStrategy and MultiWorkerMirroredStrategy create variables on each replica, whereas ParameterServerStrategy creates variables on the parameter servers. This is done inside a strategy-customized tf.variable_creator_scope.
  • Some strategies also enter a default device scope: in MultiWorkerMirroredStrategy, the default device scope entered on each worker is /CPU:0.

Note: entering a scope does not automatically distribute computation, except for high-level training frameworks like Keras model.fit. If you are not using model.fit, you need to use the strategy.run API to explicitly distribute the computation.
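As a minimal sketch of the first point (the device names are hypothetical, and the concrete class name of the default strategy may differ between versions):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])  # hypothetical devices

with strategy.scope():
  # Inside the scope, get_strategy() returns the installed strategy.
  assert tf.distribute.get_strategy() is strategy

# Outside the scope, a default no-op strategy is returned instead.
print(type(tf.distribute.get_strategy()).__name__)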

1.4.3 What belongs inside the scope

What belongs inside the scope, and what can stay outside?

  • Anything that creates distributed variables must be called inside strategy.scope. This can be done either by calling the variable-creating function directly inside the scope context, or by having strategy.run or Keras.Model.fit enter it for you automatically.
  • Any function that may lazily create variables (for example Model.call(), tracing a tf.function, etc.) should also be called inside the scope.
  • Another source of variable creation is restoring from a checkpoint.
  • Any variable created outside the scope will not be distributed.

Note that variables created inside the strategy capture the strategy information, so reading and writing them outside strategy.scope also works seamlessly without the user having to enter the scope.

Some strategy APIs that need to run inside the strategy scope (such as strategy.run and strategy.reduce) enter the scope automatically, which means you do not have to enter the scope explicitly when using them.

Models, optimizers and metrics may create variables in TF; such objects should always be initialized inside the scope. When a tf.keras.Model is created inside strategy.scope, the Model object captures the scope information. When high-level training framework methods such as model.compile and model.fit are called, the captured scope is entered automatically and the associated strategy is used to distribute the training (a minimal sketch follows the warning below).

Warning: simply calling model(..) does not automatically enter the Strategy's scope -- only the high-level training framework APIs support this behavior: model.compile, model.fit, model.evaluate, model.predict and model.save can all be called inside or outside the scope.
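As a small illustration of the Keras path, here is a minimal sketch (my own example, assuming a MirroredStrategy and toy data): the model and optimizer are created inside the scope so their variables are mirrored, and fit then distributes training using the captured strategy.

import numpy as np
import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()  # uses all visible GPUs by default

with strategy.scope():
  # Variables of the model and optimizer are created under the strategy.
  model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_shape=(4,))])
  model.compile(optimizer="sgd", loss="mse")

# fit can be called outside the scope; the captured strategy distributes training.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
model.fit(x, y, batch_size=8, epochs=1)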

1.5 StrategyExtendedV2

StrategyExtendedV2 provides additional APIs for algorithms that need to be distribution-aware.

@tf_export("distribute.StrategyExtended", v1=[])
class StrategyExtendedV2(object):
  """Additional APIs for algorithms that need to be distribution-aware."""

1.5.1 Locality

A tf.distribute.DistributedValues can have the same locality as a distributed variable, which leads to the mirrored value residing on the same devices as the variable (rather than on the compute devices). With respect to locality, users can:

  • Use tf.distribute.StrategyExtended.update to update the value of a variable.
  • Use tf.distribute.StrategyExtended.colocate_vars_with to give one variable the same locality as another.
  • Use tf.distribute.StrategyExtended.reduce_to or tf.distribute.StrategyExtended.batch_reduce_to to convert a PerReplica value to the locality of another variable.

1.5.2 How updates work

Next, let's look at how a distributed variable is updated. A distributed variable is a variable created on multiple devices, such as a mirrored variable or a SyncOnRead variable. The standard pattern for updating a distributed variable is:

  1. In the function passed to tf.distribute.Strategy.run, compute a list of (update, variable) pairs. For example, the update might be the gradient of the loss with respect to the variable.
  2. Switch to cross-replica mode by calling tf.distribute.get_replica_context().merge_call(), passing the updates and variables as arguments.
  3. Sum the updates by calling tf.distribute.StrategyExtended.reduce_to(VariableAggregation.SUM, t, v) (for a single variable) or tf.distribute.StrategyExtended.batch_reduce_to (for a list of variables).
  4. Call tf.distribute.StrategyExtended.update(v) for each variable to update its value.

If you call the tf.keras.optimizers.Optimizer.apply_gradients method in a replica context, steps 2 to 4 are performed automatically by the tf.keras.optimizers.Optimizer class.

In fact, a higher-level way to update a distributed variable is to call assign on it, just as you would on a regular tf.Variable. You can call it in both the replica context and the cross-replica context.

For a mirrored variable, calling assign in a replica context requires specifying the aggregation type in the variable constructor; in that case the context switching and synchronization described in steps 2 to 4 are handled for you. If you call assign on a mirrored variable in a cross-replica context, you can only assign a single value, or assign values from a mirrored tf.distribute.DistributedValues. For a SyncOnRead variable, in a replica context you can simply call assign and no aggregation happens; in a cross-replica context you can only assign a single value to it.
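Below is a minimal sketch of the replica-context assign path for a mirrored variable (my own example, assuming two replicas and hypothetical device names): the aggregation type is specified in the constructor, and assign_add inside strategy.run is then aggregated across replicas for you.

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])  # hypothetical devices

with strategy.scope():
  # aggregation must be specified to allow assign in a replica context.
  v = tf.Variable(0.0, aggregation=tf.VariableAggregation.SUM)

@tf.function
def step():
  def replica_fn():
    v.assign_add(1.0)  # each replica contributes 1.0; SUM aggregates the updates
  strategy.run(replica_fn)

step()
print(v.numpy())  # 2.0 with two replicas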

1.6 Class hierarchy

The Strategy class hierarchy is shown below; the V1 classes form one line of descent and the V2 classes another.

Figure 1: Strategy class hierarchy

The Extended class hierarchy is as follows:

Figure 2: Extended class hierarchy

So far we have gone through the Strategy class hierarchy, but we have not yet seen what makes Strategy clever, so we need to keep digging. This article first looks at how input data is handled; the next one looks at how variables are handled.

2. Reading data

Next let's look at how data is read. There are two main ways to build the input dataset:

  • experimental_distribute_dataset: produces a tf.distribute.DistributedDataset from a tf.data.Dataset; the resulting dataset can be iterated over just like a regular dataset.
  • _distribute_datasets_from_function: distributes tf.data.Dataset instances created by calling dataset_fn.

We will use MirroredStrategy to analyze how data is read. The overall logic is roughly: replicate, rebatch and shard the dataset on every worker. Sharding first tries to shard by file, so that each worker sees a different subset of the files; if that is not possible, it tries to shard the final input, so that each worker runs the whole preprocessing pipeline but only receives its own shard of the dataset, thereby achieving data parallelism.

2.1 Distributing a dataset directly

2.1.1 Example

Here is how to use experimental_distribute_dataset to obtain a distributed dataset directly.

>>> global_batch_size = 2
>>> # Passing the devices is optional.
... strategy = tf.distribute.MirroredStrategy(devices=["GPU:0", "GPU:1"])
>>> # Create a dataset
... dataset = tf.data.Dataset.range(4).batch(global_batch_size)
>>> # Distribute that dataset
... dist_dataset = strategy.experimental_distribute_dataset(dataset)
>>> @tf.function
... def replica_fn(input):
...   return input*2
>>> result = []
>>> # Iterate over the tf.distribute.DistributedDataset
... for x in dist_dataset:
...   # process dataset elements
...   result.append(strategy.run(replica_fn, args=(x,)))
>>> print(result)
[PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([0])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([2])>
}, PerReplica:{
  0: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([4])>,
  1: <tf.Tensor: shape=(1,), dtype=int64, numpy=array([6])>
}]

2.1.2 Base class implementation

In StrategyBase, the three main data-related operations are batching, sharding and prefetching (readers may want to compare this with the data-loading machinery in PyTorch).

In the code snippet above, the batching works as follows:

  • dataset is first batched by global_batch_size.
  • Then experimental_distribute_dataset rebatches the dataset with a new batch size that equals "the global batch size divided by the number of replicas in sync". The result can be iterated over with a Pythonic for loop.
  • x is a tf.distribute.DistributedValues that contains data for all replicas, where each replica gets data of the new (per-replica) batch size.
  • tf.distribute.Strategy.run takes care of handing each replica its corresponding per-replica slice of x and running the work function replica_fn on it.

Sharding includes autosharding across multiple workers:

  • First, in multi-worker distributed training (when using tf.distribute.experimental.MultiWorkerMirroredStrategy or tf.distribute.TPUStrategy), autosharding the dataset over a set of workers means that each worker is assigned a subset of the entire dataset (provided the right tf.data.experimental.AutoShardPolicy is set). This ensures that at each step every worker processes a global batch of non-overlapping dataset elements. Autosharding has several options that can be specified with tf.data.experimental.DistributeOptions.
  • Second, sharding within each worker means the method splits the data across all worker devices (if there is more than one). This happens regardless of whether multi-worker autosharding is enabled.
  • For autosharding across multiple workers, the default mode is tf.data.experimental.AutoShardPolicy.AUTO. If the dataset is created from reader datasets (e.g. tf.data.TFRecordDataset, tf.data.TextLineDataset, etc.), this mode tries to shard by file; otherwise it shards by data, in which case every worker reads the entire dataset but only processes the shard assigned to it. However, if there is less than one input file per worker, we recommend disabling cross-worker autosharding by setting tf.data.experimental.DistributeOptions.auto_shard_policy to tf.data.experimental.AutoShardPolicy.OFF (see the sketch below).
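For reference, here is a minimal sketch of turning cross-worker autosharding off through tf.data options before handing the dataset to the strategy (the file name is hypothetical, and strategy is assumed to be created elsewhere):

import tensorflow as tf

dataset = tf.data.TFRecordDataset(["train-0.tfrecord"])  # hypothetical file
dataset = dataset.batch(64)

options = tf.data.Options()
# Disable autosharding across workers; each worker then reads the full dataset.
options.experimental_distribute.auto_shard_policy = (
    tf.data.experimental.AutoShardPolicy.OFF)
dataset = dataset.with_options(options)

# dist_dataset = strategy.experimental_distribute_dataset(dataset)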

For prefetching, by default the method adds a prefetch transformation at the end of the user-provided tf.data.Dataset instance, with a buffer_size equal to the number of replicas in sync.

experimental_distribute_dataset is defined as follows; it simply delegates the work to extended.

  def experimental_distribute_dataset(self, dataset, options=None):
    """Creates tf.distribute.DistributedDataset from tf.data.Dataset.

    Args:
      dataset: tf.data.Dataset that will be sharded across all replicas using
        the rules stated above.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.

    Returns:
      A tf.distribute.DistributedDataset.
    """
    distribution_strategy_input_api_counter.get_cell(
        self.__class__.__name__, "distribute_dataset").increase_by(1)

    return self._extended._experimental_distribute_dataset(dataset, options)  

2.1.3 MirroredExtended implementation

Let's look at the concrete implementation in MirroredExtended: it simply calls input_lib.get_distributed_dataset, so we dive into input_lib next.

def _experimental_distribute_dataset(self, dataset, options):
  if (options and options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    raise NotImplementedError(
        "InputReplicationMode.PER_REPLICA "
        "is only supported in "
        "distribute_datasets_from_function."
    )
  return input_lib.get_distributed_dataset(
      dataset,
      self._input_workers_with_options(options),
      self._container_strategy(),
      num_replicas_in_sync=self._num_replicas_in_sync,
      options=options)

2.1.4 input_lib functionality

input_lib provides basic utilities for handling input data. get_distributed_dataset is a general-purpose function that all strategies can use to return a distributed dataset. The returned distributed dataset instance differs depending on whether we are in a TF1 or TF2 context, and so does its API. DistributedDataset and input_workers are used here, so we analyze them one by one.

def get_distributed_dataset(dataset,
                            input_workers,
                            strategy,
                            num_replicas_in_sync=None,
                            input_context=None,
                            options=None,
                            build=True):
  """Returns a distributed dataset from the given tf.data.Dataset instance.

  Args:
    dataset: a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    num_replicas_in_sync: Optional integer. If this is not None, the value is
        used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to dataset's batch size.
    input_context: InputContext for sharding. Only pass this in for between
        graph multi-worker cases where there is only one input_worker. In
        these cases, we will shard based on the input_pipeline_id and
        num_input_pipelines in the InputContext.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.
    build: whether to build underlying datasets when a DistributedDataset is
        created. This is only useful for ParameterServerStrategy now.

  Returns:
    A distributed dataset instance.
  """
  if tf2.enabled():
    return DistributedDataset(  # DistributedDataset is analyzed below
        input_workers,
        strategy,
        dataset,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        build=build,
        options=options)
  else:
    return DistributedDatasetV1(
        dataset,
        input_workers,  # InputWorkers is analyzed below
        strategy,
        num_replicas_in_sync=num_replicas_in_sync,
        input_context=input_context,
        options=options)

2.1.5 InputWorkers

Definition

InputWorkers provides a 1-to-many mapping from input worker devices to compute devices. worker_device_pairs is the list of mappings; each item is (input device, a tuple of compute devices fed by that input device).
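As a rough illustration (the device strings are hypothetical), for a single-host MirroredStrategy with two GPUs the mapping would look roughly like this: the host CPU is the input device that feeds both GPUs.

# One (input device, compute devices) pair per input worker.
worker_device_pairs = [
    ("/job:localhost/replica:0/task:0/device:CPU:0",
     ("/job:localhost/replica:0/task:0/device:GPU:0",
      "/job:localhost/replica:0/task:0/device:GPU:1")),
]
# input_workers = input_lib.InputWorkers(worker_device_pairs)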

class InputWorkers(object):
  """A 1-to-many mapping from input worker devices to compute devices."""

  # TODO(ishark): Remove option canonicalize_devices and make all the callers
  # pass canonicalized or raw device strings as relevant from strategy.
  def __init__(self,
               worker_device_pairs,
               canonicalize_devices=True):
    """Initialize an InputWorkers object.

    Args:
      worker_device_pairs: A sequence of pairs: (input device, a tuple of
        compute devices fed by that input device).
      canonicalize_devices: Whether to canonicalize devices for workers fully or
        partially. If False, it will partially canonicalize devices by removing
        job and task.
    """
    self._worker_device_pairs = worker_device_pairs
    self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs)
    self._canonicalize_devices = canonicalize_devices
    if canonicalize_devices:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize(d)
                for d in f)
          for _, f in self._worker_device_pairs)
    else:
      self._fed_devices = tuple(
          tuple(device_util.canonicalize_without_job_and_task(d)
                for d in f)
          for _, f in self._worker_device_pairs)

  @property
  def num_workers(self):
    return len(self._input_worker_devices)

  @property
  def worker_devices(self):
    return self._input_worker_devices  # return the input worker devices

  def compute_devices_for_worker(self, worker_index):
    return self._fed_devices[worker_index]

  def __repr__(self):
    devices = self.worker_devices
    debug_repr = ",\n".join("  %d %s: %s" %
                            (i, devices[i], self._fed_devices[i])
                            for i in range(len(devices)))
    return "%s:{\n%s}" % (self.__class__.__name__, debug_repr)

  def serialize(self):
    return (self._worker_device_pairs, self._canonicalize_devices)

  def deserialize(self, serialized):
    return InputWorkers(serialized)

Construction

MirroredStrategy has a member property _input_workers, so an InputWorkers object is generated whenever it is accessed.

@property
def _input_workers(self):
  return self._input_workers_with_options()

_input_workers_with_options configures the workers based on self._devices: it builds the various mappings and passes them in.

def _input_workers_with_options(self, options=None):
  if not options:
    # No options given: create InputWorkers directly.
    return input_lib.InputWorkers(self._input_workers_devices)

  # Otherwise build according to the given options.
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    # PER_REPLICA handling
    if options.experimental_place_dataset_on_device:
      self._input_workers_devices = (
          tuple(
              (device_util.canonicalize(d, d), (d,)) for d in self._devices))
    else:
      self._input_workers_devices = (
          tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                for d in self._devices))
    return input_lib.InputWorkers(self._input_workers_devices)
  else:
    if not options.experimental_fetch_to_device:
      return input_lib.InputWorkers([
          (host_device, (host_device,) * len(compute_devices))
          for host_device, compute_devices in self._input_workers_devices
      ])
    else:
      return input_lib.InputWorkers(self._input_workers_devices)

The device_util.canonicalize method used here canonicalizes device strings.

def canonicalize(d, default=None):
  """Canonicalize device string.

  If d has missing components, the rest would be deduced from the default
  argument or from '/replica:0/task:0/device:CPU:0'. For example:
    If d = '/cpu:0', default='/job:worker/task:1', it returns
      '/job:worker/replica:0/task:1/device:CPU:0'.
    If d = '/cpu:0', default='/job:worker', it returns
      '/job:worker/replica:0/task:0/device:CPU:0'.
    If d = '/gpu:0', default=None, it returns
      '/replica:0/task:0/device:GPU:0'.

  Note: This uses "job:localhost" as the default if executing eagerly.

  Args:
    d: a device string or tf.config.LogicalDevice
    default: a string for default device if d doesn't have all components.

  Returns:
    a canonicalized device string.
  """
  if isinstance(d, context.LogicalDevice):
    d = tf_device.DeviceSpec.from_string(d.name)
  else:
    d = tf_device.DeviceSpec.from_string(d)

  # Fill in missing device fields using defaults.
  result = tf_device.DeviceSpec(
      replica=0, task=0, device_type="CPU", device_index=0)
  if ops.executing_eagerly_outside_functions():
    # Try to deduce job, replica and task in case it's in a multi worker setup.
    host_cpu = tf_device.DeviceSpec.from_string(
        config.list_logical_devices("CPU")[0].name)
    if host_cpu.job:
      result = result.make_merged_spec(host_cpu)
    else:
      # The default job is localhost if eager execution is enabled
      result = result.replace(job="localhost")
  if default:
    # Overrides any defaults with values from the default device if given.
    result = result.make_merged_spec(
        tf_device.DeviceSpec.from_string(default))

  # Apply d last, so that it's values take precedence over the defaults.
  result = result.make_merged_spec(d)
  return result.to_string()

2.1.6 DistributedDataset

DistributedDataset supports prefetching data to multiple devices.

Initialization

A large amount of validation code is omitted below; the key point is the call to the build method.

class DistributedDataset(_IterableInput, composite_tensor.CompositeTensor):
  """Distributed dataset that supports prefetching to multiple devices."""

  def __init__(self,
               input_workers,
               strategy,
               dataset=None,
               num_replicas_in_sync=None,
               input_context=None,
               components=None,
               element_spec=None,
               enable_get_next_as_optional=None,
               build=True,
               options=None):
    """Distribute the dataset on all workers.

    If num_replicas_in_sync is not None, we split each batch of the dataset
    into num_replicas_in_sync smaller batches, to be distributed among that
    worker's replicas, so that the batch size for a global step (across all
    workers and replicas) is as expected.

    Args:
      input_workers: an InputWorkers object.
      strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
      dataset: tf.data.Dataset that will be used as the input source. Either
        dataset or components field should be passed when constructing
        DistributedDataset. Use this when contructing DistributedDataset from a
        new tf.data.Dataset. Use components when constructing using
        DistributedDatasetSpec.
      num_replicas_in_sync: Optional integer. If this is not None, the value
        is used to decide how to rebatch datasets into smaller batches so that
        the total batch size for each step (across all workers and replicas)
        adds up to dataset's batch size.
      input_context: InputContext for sharding. Only pass this in for between
        graph multi-worker cases where there is only one input_worker. In
        these cases, we will shard based on the input_pipeline_id and
        num_input_pipelines in the InputContext.
      components: datasets when DistributedDataset is constructed from
        DistributedDatasetSpec. Either field dataset or components should be
        passed.
      element_spec: element spec for DistributedDataset when constructing from
        DistributedDatasetSpec. This will be used to set the element_spec for
        DistributedDataset and verified against element_spec from components.
      enable_get_next_as_optional: this is required when components is passed
        instead of dataset.
      build: whether to build underlying datasets when this object is created.
        This is only useful for ParameterServerStrategy now.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.
    """
    super(DistributedDataset, self).__init__(input_workers=input_workers)

    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options
    self._input_context = input_context
    self._num_replicas_in_sync = num_replicas_in_sync

    if dataset is not None:
      self._original_dataset = dataset
      self._built = False
      if build:
        self.build()  # this is the key step
    else:
      self._cloned_datasets = components
      self._cardinality = _cardinality(self._cloned_datasets[0])
      self._enable_get_next_as_optional = enable_get_next_as_optional

      if element_spec != _create_distributed_tensor_spec(
          self._strategy, self._cloned_datasets[0].element_spec):
        raise ValueError("Mismatched element_spec from the passed components")
      self._element_spec = element_spec

      self._built = True

Building the datasets

The main job of build is to call _create_cloned_datasets_from_dataset.

def build(self, dataset_to_replace=None):
  dataset = dataset_to_replace or self._original_dataset
  self._cardinality = _cardinality(dataset)
  self._enable_get_next_as_optional = _enable_get_next_as_optional(
      self._strategy, dataset, self._cardinality)
  self._create_cloned_datasets_from_dataset(dataset, self._input_context,
                                            self._input_workers,
                                            self._strategy,
                                            self._num_replicas_in_sync)
  self._element_spec = _create_distributed_tensor_spec(
      self._strategy, self._cloned_datasets[0].element_spec)
  self._built = True

_create_cloned_datasets_from_dataset clones and shards the dataset on every worker (this is where InputWorkers is used to obtain device information). It first tries to shard by file, so that each worker sees a different subset of the files; if that is not possible, it tries to shard the final input, so that each worker runs the entire preprocessing pipeline and only receives its own shard of the dataset.

In addition, _create_cloned_datasets_from_dataset rebatches the dataset on each worker into num_replicas_in_sync smaller batches. These smaller batches are distributed among that worker's replicas, so that the batch size of a global step (across all workers and replicas) adds up to the original dataset's batch size.

def _create_cloned_datasets_from_dataset(self, dataset, input_context,
                                         input_workers, strategy,
                                         num_replicas_in_sync):
  if num_replicas_in_sync is not None:
    num_workers = input_context.num_input_pipelines if input_context else len(
        input_workers.worker_devices)
    # Use _make_rebatch_fn to rebatch the data
    rebatch_fn = self._make_rebatch_fn(dataset, num_workers,
                                       num_replicas_in_sync)
  else:
    rebatch_fn = None
  self._cloned_datasets = []
  
  if input_context:
    # Between-graph where we rely on the input_context for sharding
    if rebatch_fn is not None:
      dataset = rebatch_fn(dataset, input_context.input_pipeline_id)
    dataset = input_ops.auto_shard_dataset(dataset,
                                           input_context.num_input_pipelines,
                                           input_context.input_pipeline_id,
                                           num_replicas_in_sync)
    self._cloned_datasets.append(dataset)
  else:
    # Replicate the dataset; distribute.replicate returns _RemoteDataset instances
    replicated_ds = distribute.replicate(dataset,
                                         input_workers.worker_devices)

    for i, worker in enumerate(input_workers.worker_devices):
      with ops.device(worker):
        cloned_dataset = replicated_ds[worker]  # the dataset for this worker
        if rebatch_fn is not None:
          cloned_dataset = rebatch_fn(cloned_dataset, i)  # rebatch; returns _RebatchDataset
        # Auto-shard; returns _AutoShardDataset
        cloned_dataset = input_ops.auto_shard_dataset(
            cloned_dataset, len(input_workers.worker_devices), i,
            num_replicas_in_sync)
        self._cloned_datasets.append(cloned_dataset)

distribute.replicate copies the dataset onto a list of devices; it returns one _RemoteDataset per device.

def replicate(dataset, devices):
  """A transformation that replicates dataset onto a list of devices.

  Args:
    dataset: A tf.data.Dataset object.
    devices: A list of devices to replicate the dataset on.

  Returns:
    A dictionary mapping device name to a dataset on that device.
  """

  dataset_device = dataset._variant_tensor.device

  datasets = {}
  if len(devices) == 1 and devices[0] == dataset_device:
    datasets[devices[0]] = dataset
    return datasets

  with ops.colocate_with(dataset._variant_tensor):
    dataset = dataset._apply_debug_options()
    graph_def = dataset._as_serialized_graph(
        strip_device_assignment=True,
        external_state_policy=ExternalStatePolicy.WARN)
    
  for device in devices:  # iterate over the devices, copying the dataset onto each; one _RemoteDataset per device
    ds = _RemoteDataset(graph_def, device, dataset.element_spec)
    datasets[device] = ds
    
  return datasets

_make_rebatch_fn returns a callable that rebatches the input dataset; the result is a _RebatchDataset.

def _make_rebatch_fn(self, dataset, num_workers, num_replicas_in_sync):
  """Returns a callable that rebatches the input dataset.

  Args:
    dataset: A tf.data.Dataset representing the dataset to be distributed.
    num_workers: An integer representing the number of workers to distribute
      dataset among.
    num_replicas_in_sync: An integer representing the number of replicas in
      sync across all workers.
  """
  if num_replicas_in_sync % num_workers:
    raise ValueError(
        "tf.distribute expects every worker to have the same number of "
        "replicas. However, encountered num_replicas_in_sync ({}) that "
        "cannot be divided by num_workers ({})".format(
            num_replicas_in_sync, num_workers))

  num_replicas_per_worker = num_replicas_in_sync // num_workers
  with ops.colocate_with(dataset._variant_tensor):  
    batch_size = distribute.compute_batch_size(dataset)

  def rebatch_fn(dataset, worker_index):
    try:
      def apply_rebatch():
        batch_sizes = distribute.batch_sizes_for_worker(
            batch_size, num_workers, num_replicas_per_worker, worker_index)
        return distribute._RebatchDataset(
            dataset, batch_sizes).prefetch(num_replicas_per_worker)

      def apply_legacy_rebatch():
        return distribute._LegacyRebatchDataset(
            dataset, num_replicas_in_sync).prefetch(num_replicas_per_worker)

      with ops.colocate_with(dataset._variant_tensor):
        return control_flow_ops.cond(
            math_ops.not_equal(batch_size, -1),
            true_fn=apply_rebatch,
            false_fn=apply_legacy_rebatch)
    except errors.InvalidArgumentError as e:
      if "without encountering a batch" in str(e):
        six.reraise(
            ValueError,
            ValueError(
                "Call the batch method on the input Dataset in order to be "
                "able to split your input across {} replicas.\n Please see "
                "the tf.distribute.Strategy guide. {}".format(
                    num_replicas_in_sync, e)),
            sys.exc_info()[2])
      else:
        raise

  return rebatch_fn

Next comes autosharding, which returns an _AutoShardDataset.

def auto_shard_dataset(dataset, num_shards, index, num_replicas_in_sync=None):
  """Shard the input pipeline by sharding the underlying list of files.

  Args:
    dataset: A tf.data.Dataset instance, typically the result of a bunch of
      dataset transformations.
    num_shards: A tf.int64 scalar tf.Tensor, representing the number of
        shards operating in parallel. Same usage as in tf.data.Dataset.shard.
    index: A tf.int64 scalar tf.Tensor, representing the worker index.
      Same usage as in tf.data.Dataset.shard.
    num_replicas_in_sync: An integer representing the total number of replicas
      across all workers. This is used in the rewrite when sharding by data.

  Returns:
    A modified Dataset obtained by updating the pipeline sharded by the
    files. The input dataset will be returned if we cannot automatically
    determine a good way to shard the input dataset.
  """
  if (dataset.options().experimental_distribute.auto_shard_policy !=
      AutoShardPolicy.OFF):
    if num_replicas_in_sync is None:
      num_replicas_in_sync = 1
    if isinstance(dataset, dataset_ops.DatasetV1):
      return distribute._AutoShardDatasetV1(dataset, num_shards, index,
                                            num_replicas_in_sync)
    else:
      return distribute._AutoShardDataset(dataset, num_shards, index,
                                          num_replicas_in_sync)
  else:
    return dataset

The flow so far is shown below: the dataset is progressively enhanced, first into a _RemoteDataset and then upgraded into an _AutoShardDataset.

Figure 3: Creating the dataset

Datasets

Since several kinds of datasets were involved above, let's sort out the relationships between them. Essentially, functionality is added step by step on top of DatasetV2, and the final result is returned to the user.

_RemoteDataset corresponds to a remote dataset.

_RemoteDataset inherits from DatasetSource, and dataset_ops.DatasetSource inherits from DatasetV2 (i.e. data.Dataset).

class DatasetSource(DatasetV2):
  """Abstract class representing a dataset with no inputs."""
 
@tf_export("data.Dataset", v1=[])
@six.add_metaclass(abc.ABCMeta)
class DatasetV2(collections_abc.Iterable, tracking_base.Trackable,
                composite_tensor.CompositeTensor):

_RemoteDataset itself is shown below; it uses with ops.device(device) to place the dataset on the remote device.

class _RemoteDataset(dataset_ops.DatasetSource):
  """Creates a dataset on a given device given a graph def."""

  def __init__(self, graph_def, device, element_spec):
    self._elem_spec = element_spec
    with ops.device(device):  # this places the dataset on the remote device
      variant_tensor = ged_ops.dataset_from_graph(graph_def)
    super(_RemoteDataset, self).__init__(variant_tensor)

  @property
  def element_spec(self):
    return self._elem_spec

_RebatchDataset rebatches its input into new batch sizes; typical usage looks like this:

  ds = tf.data.Dataset.range(8)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[2, 1, 1])
  for elem in ds:
    print(elem)
  >> [0, 1], [2], [3], [4, 5], [6], [7]

  ds = tf.data.Dataset.range(16)
  ds = ds.batch(4)
  ds = _RebatchDataset(ds, batch_sizes=[6])
  for elem in ds:
    print(elem)
  >> [0, 1, 2, 3, 4, 5], [6, 7, 8, 9, 10, 11], [12, 13, 14, 15]

The code is as follows:

class _RebatchDataset(dataset_ops.UnaryDataset):
  """A Dataset that rebatches elements from its input into new batch sizes.

  _RebatchDataset(input_dataset, batch_sizes) is functionally equivalent to
  input_dataset.unbatch().batch(N), where the value of N cycles through the
  batch_sizes input list. The elements produced by this dataset have the same
  rank as the elements of the input dataset.

  """

  def __init__(self, input_dataset, batch_sizes, drop_remainder=False):
    """Creates a _RebatchDataset.

    Args:
      input_dataset: Dataset to rebatch.
      batch_sizes: A tf.int64 scalar or vector, representing the size of
        batches to produce. If this argument is a vector, these values are
        cycled through in order.
      drop_remainder: (Optional.) A tf.bool scalar tf.Tensor, representing
        whether the last batch should be dropped in the case it has fewer than
        batch_sizes[cycle_index] elements; the default behavior is not to drop
        the smaller batch.
    """
    self._input_dataset = input_dataset
    self._batch_sizes = ops.convert_to_tensor(
        batch_sizes, dtype=dtypes.int64, name="batch_sizes")
    self._drop_remainder = ops.convert_to_tensor(
        drop_remainder, dtype=dtypes.bool, name="drop_remainder")
    new_batch_dim = self._compute_static_batch_dim()

    self._element_spec = nest.map_structure(
        lambda ts: ts._unbatch()._batch(new_batch_dim),
        dataset_ops.get_structure(input_dataset))

    # auto_shard rewrite assumes that there's normalize_to_dense before
    # rebatch_dataset.
    # LINT.IfChange
    input_dataset = dataset_ops.normalize_to_dense(input_dataset)
    variant_tensor = ged_ops.rebatch_dataset_v2(
        input_dataset._variant_tensor,  
        batch_sizes=batch_sizes,
        drop_remainder=drop_remainder,
        **self._flat_structure)
    super(_RebatchDataset, self).__init__(input_dataset, variant_tensor)

  def _compute_static_batch_dim(self):
    """Computes the static batch dimension of a dataset if it can be determined.

    Given the _RebatchDataset parameters, determines the batch dimension of this
    dataset statically. Returns None if this cannot be determined or is
    variable.

    Returns:
      An integer representing the batch dimension of the dataset. If it cannot
      be determined statically, returns None.

    Raises:
      ValueError: The batch_sizes parameter is malformed, input_dataset is
      not batched, or input_dataset batch sizes are incompatible with each
      other.
    """
    new_batch_dim = tensor_util.constant_value(self._batch_sizes)
    if new_batch_dim is None:
      return None

    if isinstance(new_batch_dim, np.ndarray):
      if len(new_batch_dim.shape) == 1:
        if np.all(new_batch_dim == new_batch_dim[0]):
          new_batch_dim = new_batch_dim[0]
        else:
          return None
      elif len(new_batch_dim.shape) > 1:
        raise ValueError(
            f"Invalid batch_sizes. Expected batch_sizes to be a scalar or "
            f"a vector. Received batch_sizes of rank "
            f"{len(new_batch_dim.shape)}."
        )

    if self._may_form_partial_batches(new_batch_dim):
      return None

    return new_batch_dim

  def _may_form_partial_batches(self, desired_batch_size):
    """Returns whether this dataset may form partial batches."""
    if tensor_util.constant_value(self._drop_remainder):
      return False

    def get_batch_dim(type_spec):
      shape = type_spec._to_legacy_output_shapes()  
      if not isinstance(shape, tensor_shape.TensorShape):
        return None
      if shape.rank is None:
        return None
      return shape.dims[0].value

    input_batch_dims = [
        get_batch_dim(ts)
        for ts in nest.flatten(dataset_ops.get_structure(self._input_dataset))
    ]
    known_input_batch_dims = [d for d in input_batch_dims if d is not None]

    if not known_input_batch_dims:
      return True

    known_input_batch_dims = np.asarray(known_input_batch_dims)

    return known_input_batch_dims[0] % desired_batch_size != 0

  @property
  def element_spec(self):
    return self._element_spec

_AutoShardDataset shards the dataset automatically.

This dataset takes an existing dataset and tries to automatically figure out how to shard it in a multi-worker scenario using graph rewrites.

  • If the AutoShardPolicy is set to FILE, it walks up the dataset graph until it finds a reader dataset, then inserts a ShardDataset op before that node, so that each worker only sees some of the files.

  • If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the end of the input pipeline, before any terminal PrefetchDataset if there is one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it is rewritten to the legacy RebatchDataset for correctness reasons, since RebatchDatasetV2 is incompatible with data sharding.

  • If the AutoShardPolicy is set to AUTO, it tries file-based sharding first. If it cannot find a reader dataset, it falls back to data-based sharding.

  • If the AutoShardPolicy is set to OFF, it does nothing.

class _AutoShardDataset(dataset_ops.UnaryDataset):
  """A Dataset that shards the Dataset automatically.

  This dataset takes in an existing dataset and tries to automatically figure
  out how to shard the dataset in a multi-worker scenario using graph rewrites.

  If the AutoShardPolicy is set to FILE, it walks up the dataset graph until
  it finds a reader dataset, then inserts a ShardDataset op before that node
  so that each worker only sees some files.

  If the AutoShardPolicy is set to DATA, it inserts a ShardDataset op at the
  end of the input pipeline, before any terminal PrefetchDataset if there is
  one. Additionally, if there is a RebatchDatasetV2 in the input pipeline, it
  is written to legacy RebatchDataset for correctness reasons, since
  RebatchDatasetV2 is incompatible with data sharding.

  If the AutoShardPolicy is set to AUTO, it tries to do file-based sharding.
  If it cannot find a reader dataset, it falls back to doing data-based
  sharding.

  If the AutoShardPolicy is set to OFF, it does nothing.

  Attributes:
    num_workers: Total number of workers to shard this dataset across.
    index: The current worker index (out of the total number of workers) this
      dataset is for.
    num_replicas: The total number of replicas across all workers. This is used
      only when sharding by data (either DATA or AUTO) in order to rewrite
      RebatchDatasetV2 to RebatchDataset.

  Raises:
    NotFoundError: If we cannot find a suitable reader dataset to begin
      automatically sharding the dataset.
  """

  def __init__(self, input_dataset, num_workers, index, num_replicas=None):
    self._input_dataset = input_dataset

    self._element_spec = input_dataset.element_spec
    variant_tensor = ged_ops.auto_shard_dataset(
        self._input_dataset._variant_tensor, 
        num_workers=num_workers,
        index=index,
        auto_shard_policy=int(
            input_dataset.options().experimental_distribute.auto_shard_policy),
        num_replicas=num_replicas,
        **self._flat_structure)
    super(_AutoShardDataset, self).__init__(input_dataset, variant_tensor)

  @property
  def element_spec(self):
    return self._element_spec

The autosharding itself is implemented in tensorflow\core\grappler\optimizers\data\auto_shard.cc, shown below; interested readers can dig deeper on their own.

Status ApplyAutoShard(const NodeDef& sink_node, int64_t num_workers,
                      int64_t index, AutoShardPolicy policy,
                      int64_t num_replicas, MutableGraphView* graph,
                      AutoShardPolicy* policy_applied) {
  *policy_applied = policy;
  FunctionLibraryDefinition flib(OpRegistry::Global(),
                                 graph->graph()->library());
  switch (policy) {
    case AutoShardPolicy::OFF:
      return Status::OK();
    case AutoShardPolicy::FILE:
      return ShardByFile(sink_node, num_workers, index, &flib, graph);
    case AutoShardPolicy::DATA:
      return ShardByData(sink_node, num_workers, index, num_replicas, graph);
    case AutoShardPolicy::HINT:
      return ShardByHint(sink_node, num_workers, index, num_replicas, graph);
    case AutoShardPolicy::AUTO:
    default:
      Status s = ShardByFile(sink_node, num_workers, index, &flib, graph);
      if (errors::IsNotFound(s)) {
        LOG(WARNING) << "AUTO sharding policy will apply DATA sharding policy "
                        "as it failed to apply FILE sharding policy because of "
                        "the following reason: "
                     << s.error_message();
        *policy_applied = AutoShardPolicy::DATA;
        return ShardByData(sink_node, num_workers, index, num_replicas, graph);
      }
      *policy_applied = AutoShardPolicy::FILE;
      return s;
  }
}

The relationships are as follows: DistributedDataset's member _cloned_datasets is a list containing several _AutoShardDataset instances, one per worker.

Figure 4: Dataset relationships

Iterating over the data

Next let's see how DistributedDataset is iterated: the __iter__ method creates one iterator per worker and finally returns a single DistributedIterator.

def __iter__(self):

  canonicalize_devices = getattr(self._strategy, "_canonicalize_devices", True)

  # Create one iterator per worker
  worker_iterators = _create_iterators_per_worker(
      self._cloned_datasets,
      self._input_workers,
      enable_legacy_iterators=False,
      options=self._options,
      canonicalize_devices=canonicalize_devices)
    
  # Wrap them into a single DistributedIterator
  iterator = DistributedIterator(
      self._input_workers,
      worker_iterators,
      self._strategy,
      cardinality=self._cardinality,
      enable_get_next_as_optional=self._enable_get_next_as_optional,
      options=self._options)
  iterator._element_spec = self._element_spec  # pylint: disable=protected-access

  # When async eager is enabled, sometimes the iterator may not finish
  # initialization before passing to a multi device function, add a sync point
  # here to make sure all underlying iterators are initialized.
  if context.executing_eagerly():
    context.async_wait()

  return iterator

_create_iterators_per_worker creates a multi-device iterator on each worker.

def _create_iterators_per_worker(worker_datasets,
                                 input_workers,
                                 enable_legacy_iterators,
                                 options=None,
                                 canonicalize_devices=False):
  """Create a multidevice iterator on each of the workers."""
  iterators = []
  for i, worker in enumerate(input_workers.worker_devices):
    with ops.device(worker):
      worker_devices = input_workers.compute_devices_for_worker(i)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(
            dataset=worker_datasets[i],
            worker=worker,
            devices=worker_devices,
            options=options,
            canonicalize_devices=canonicalize_devices)
      else:
        iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker,
                                                worker_devices, options)
      iterators.append(iterator)
  return iterators

_SingleWorkerDatasetIterator in turn creates a MultiDeviceIterator.

class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase):
  """Iterator for a single DistributedDatasetV1 instance."""

  def _make_iterator(self):
    """Make appropriate iterator on the dataset."""
    with ops.device(self._worker):
      if self._options is not None:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
            max_buffer_size=self._options.experimental_per_replica_buffer_size,
            prefetch_buffer_size=self._options
            .experimental_per_replica_buffer_size)
      else:
        self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
            self._dataset,
            self._devices,
        )

  def initialize(self):
    """Initialize underlying iterator.

    In eager execution, this simply recreates the underlying iterator.
    In graph execution, it returns the initializer ops for the underlying
    iterator.

    Returns:
      A list of any initializer ops that should be run.
    """
    if ops.executing_eagerly_outside_functions():
      self._iterator._eager_reset()  # pylint: disable=protected-access
      return []
    else:
      return [self._iterator.initializer]

  @property
  def output_classes(self):
    return dataset_ops.get_legacy_output_classes(self._iterator)

  @property
  def output_shapes(self):
    return dataset_ops.get_legacy_output_shapes(self._iterator)

  @property
  def output_types(self):
    return dataset_ops.get_legacy_output_types(self._iterator)

The logic is as follows:

Figure 5: Getting the iterator

2.1.7 DistributedIterator

Next we look at DistributedIterator.

DistributedIterator

DistributedIterator itself does little actual work; the main functionality lives in its base class, DistributedIteratorBase.

class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self,
               input_workers=None,
               iterators=None,
               strategy=None,
               components=None,
               element_spec=None,
               cardinality=cardinality_lib.UNKNOWN,
               enable_get_next_as_optional=False,
               options=None):

    error_message = ("Either input_workers or "
                     "both components and element_spec need to be "
                     "provided.")
    self._options = options

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      self._strategy = strategy
      self._cardinality = cardinality
      self._enable_get_next_as_optional = enable_get_next_as_optional
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator,
            self).__init__(input_workers, iterators, strategy, cardinality,
                           enable_get_next_as_optional)

  @property
  def element_spec(self):
    # When partial batch handling is enabled, always set the batch dimension to
    # None, otherwise we just follow element_spec of the underlying dataset
    # (whose batch dimension may also be None). This is because with partial
    # batching handling we could always produce empty batches.
    if (self._enable_get_next_as_optional and
        self._strategy.extended._in_multi_worker_mode()):  # pylint: disable=protected-access
      return nest.map_structure(
          _rebatch_as_dynamic, self._element_spec, expand_composites=False)
    return self._element_spec

  @property
  def _type_spec(self):
    # Note that we use actual element_spec instead of the rebatched-as-dynamic
    # one to create DistributedIteratorSpec, to be consistent with the
    # underlying iterators' specs.
    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy,
                                   self._options,
                                   self._cardinality,
                                   self._enable_get_next_as_optional)

DistributedIteratorBase

The main methods of DistributedIteratorBase are the same as those of an ordinary iterator.

class DistributedIteratorBase(DistributedIteratorInterface):
  """Common implementation for all input iterators."""

  # pylint: disable=super-init-not-called
  def __init__(self, input_workers, iterators, strategy, cardinality,
               enable_get_next_as_optional):

    self._iterators = iterators
    self._input_workers = input_workers
    self._strategy = strategy
    self._cardinality = cardinality
    self._enable_get_next_as_optional = enable_get_next_as_optional

  def next(self):
    return self.__next__()

  def __next__(self):
    try:
      return self.get_next()
    except errors.OutOfRangeError:
      raise StopIteration

  def __iter__(self):
    return self

get_next does the actual data fetching. Let's focus on _create_per_replica, which looks most closely related to distribution. Concretely, get_next:

  • finds all worker information;
  • counts the replicas that have values;
  • fetches the data and regroups it.

def get_next(self, name=None):
  """Returns the next input from the iterator for all replicas."""
  with distribution_strategy_context.enter_or_assert_strategy(
      self._strategy):

    if not self._enable_get_next_as_optional:
      return self._get_next_no_partial_batch_handling(name)

    optional_list = []
    # Collect an Optional per compute device, grouped by input worker
    for i, worker in enumerate(self._input_workers.worker_devices):
      with ops.device(worker):
        optional_list.append(self._iterators[i].get_next_as_optional_list())

    # Count how many replicas actually have values
    num_replicas_with_values = _calculate_replicas_with_values(
        self._strategy, self._input_workers, optional_list)

    # Fetch the data and regroup it
    def _value_or_dummy():
      value_list = _get_value_or_dummy(  # fetch the data
          self._input_workers, optional_list, produce_dummy=True)
      return _create_per_replica(value_list, self._strategy)

    def _eof():
      # Optional.get_value raises InvalidArgumentError when there's no value,
      # so we need to call GetNext to raise EOFError.
      return self._get_next_no_partial_batch_handling()

    return control_flow_ops.cond(
        num_replicas_with_values > 0, _value_or_dummy, _eof, strict=True)

_calculate_replicas_with_values computes the number of replicas that have values.

def _calculate_replicas_with_values(strategy, input_workers, optional_list):
  """Calcualates the number of replicas that have values.

  Args:
    strategy: the tf.distribute.Strategy.
    input_workers: the InputWorkers.
    optional_list: a list of lists tf.experimental.Optional. The values from
      each compute device grouped by the input device.

  Returns:
    A scalar Tensor.
  """
  worker_has_values = []
  for worker, optionals in zip(input_workers.worker_devices, optional_list):
    with ops.device(worker):
      device_has_values = [
          math_ops.cast(v.has_value(), dtypes.int64) for v in optionals
      ]
      worker_has_values.append(
          math_ops.reduce_sum(device_has_values, keepdims=True))
  client_has_values = math_ops.reduce_sum(worker_has_values, keepdims=True)
  if strategy.extended._in_multi_worker_mode():  
    global_has_values = strategy.reduce(
        reduce_util.ReduceOp.SUM, client_has_values, axis=None)
    return array_ops.reshape(global_has_values, [])
  else:
    return array_ops.reshape(client_has_values, [])

_get_value_or_dummy fetches the concrete data.

def _get_value_or_dummy(input_workers, optional_list, produce_dummy):
  """Returns the value of the optionals or dummy values.

  Args:
    input_workers: the InputWorkers.
    optional_list: a list of lists tf.experimental.Optional. The values from
      each compute device grouped by the input device.
    produce_dummy: a bool. Whether to produce dummy tensors when the optional
      doesn't have a value.

  Returns:
    A flatten list of Tensors.

  """
  value_list = []
  for i, worker in enumerate(input_workers.worker_devices):  # iterate over the workers
    with ops.device(worker):
      devices = input_workers.compute_devices_for_worker(i)  # devices within this worker
      for j, device in enumerate(devices):
        with ops.device(device):
          if produce_dummy:
            value_list.append(  # accumulate the values
                control_flow_ops.cond(
                    optional_list[i][j].has_value(),
                    lambda: optional_list[i][j].get_value(),  
                    lambda: _dummy_tensor_fn(optional_list[i][j].element_spec),
                    strict=True,
                ))
          else:
            value_list.append(optional_list[i][j].get_value())
  return value_list

_create_per_replica regroups the concrete values:

  • For strategies other than OneDeviceStrategy, it creates a PerReplica whose type spec is set to the element spec of the dataset. This helps avoid retracing for partial batches. Retracing is problematic in multi-client setups, because when different clients retrace at different times, the retracing changes the collective keys in the tf.function and causes mismatches among clients.
  • For single-client strategies, _create_per_replica simply calls distribute_utils.regroup().

def _create_per_replica(value_list, strategy):
  """Creates a PerReplica.

  For strategies other than OneDeviceStrategy, it creates a PerReplica whose
  type spec is set to the element spec of the dataset. This helps avoid
  retracing for partial batches. Retracing is problematic for multi client when
  different client retraces different time, since retracing changes the
  collective keys in the tf.function, and causes mismatches among clients.

  For single client strategies, this simply calls distribute_utils.regroup().

  Args:
    value_list: a list of values, one for each replica.
    strategy: the tf.distribute.Strategy.

  Returns:
    a structure of PerReplica.

  """
  always_wrap = _always_wrap(strategy)
  per_replicas = distribute_utils.regroup(value_list, always_wrap=always_wrap)
  return per_replicas

The logic is as follows:

Figure 6: Iterator processing data

At this point we have a fairly solid basic picture of data reading. The logic between the main classes is:

  • InputWorkers provides a 1-to-many mapping from input worker devices to compute devices; you can think of it as binding workers to devices.
  • DistributedDataset is the dataset itself. Internally it goes through a series of processing steps: the dataset is replicated onto a set of devices and then progressively enhanced, first into _RemoteDataset and then upgraded into _AutoShardDataset.
  • DistributedDataset's __iter__ method creates one iterator per worker and finally returns a single DistributedIterator.
  • DistributedIterator's get_next method fetches the data. Taking _create_per_replica as an example, the concrete steps are:
    • find all worker information;
    • count the replicas that have values;
    • fetch the data and regroup it.

This is shown in the figure below (a rough conceptual view, only meant to aid understanding); the numbers correspond to the figure.

  • 1. InputWorkers provides the mapping between workers and devices.
  • 2. The dataset is distributed to the devices, i.e. to the workers.
  • 3. One iterator is created per worker, and a single DistributedIterator is returned.

2.2 Initializing via a function

If the batch splitting and dataset sharding logic above (i.e. reading a dataset directly) does not meet your needs, you can use tf.distribute.Strategy.distribute_datasets_from_function, which does no automatic batch splitting or sharding at all.

2.2.1 StrategyBase

We start with StrategyBase. distribute_datasets_from_function distributes tf.data.Dataset instances that are created by calling dataset_fn.

The user-supplied argument dataset_fn is an input function that takes a tf.distribute.InputContext argument and returns a tf.data.Dataset instance. A dataset returned from dataset_fn is expected to already be batched with the per-replica batch size (i.e. the global batch size divided by the number of replicas in sync) and to already be sharded.

tf.distribute.Strategy.distribute_datasets_from_function does not batch or shard the tf.data.Dataset instances returned by the input function. dataset_fn is called on the CPU device of each worker, and each call produces a dataset from which every replica on that worker pulls one batch of input (for example, if a worker has two replicas, two batches are pulled from the Dataset at every step). This method is used in several situations:

  • First, it lets you specify your own batching and sharding logic, whereas tf.distribute.Strategy.experimental_distribute_dataset does batching and sharding for you. For example, when experimental_distribute_dataset cannot shard the input files, this method can be used to shard manually (and avoid experimental_distribute_dataset's slow fallback behavior).
  • For infinite datasets, sharding can be done by creating dataset replicas that differ only in their random seed.

dataset_fn should take a tf.distribute.InputContext instance, which carries information about batching and the input replicas. The tf.data.Dataset returned by dataset_fn should use a per-replica batch size, unlike experimental_distribute_dataset, which uses the global batch size. The per-replica batch size can be computed from the global batch size with input_context.get_per_replica_batch_size (a minimal sketch follows).
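A minimal sketch of such an input function (my own example with hypothetical numbers; strategy is assumed to be defined elsewhere):

import tensorflow as tf

GLOBAL_BATCH_SIZE = 64

def dataset_fn(input_context):
  # Per-replica batch size derived from the global batch size.
  batch_size = input_context.get_per_replica_batch_size(GLOBAL_BATCH_SIZE)
  dataset = tf.data.Dataset.range(1024)
  # Manual sharding: each input pipeline only sees its own slice.
  dataset = dataset.shard(input_context.num_input_pipelines,
                          input_context.input_pipeline_id)
  return dataset.batch(batch_size)

# dist_dataset = strategy.distribute_datasets_from_function(dataset_fn)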

def distribute_datasets_from_function(self, dataset_fn, options=None):
  """
  Args:
    dataset_fn: A function taking a tf.distribute.InputContext instance and
      returning a tf.data.Dataset.
    options: tf.distribute.InputOptions used to control options on how this
      dataset is distributed.

  Returns:
    A tf.distribute.DistributedDataset.
  """
  distribution_strategy_input_api_counter.get_cell(
      self.__class__.__name__,
      "distribute_datasets_from_function").increase_by(1)

  return self._extended._distribute_datasets_from_function(  
      dataset_fn, options)

2.2.2 MirroredStrategy

We again use MirroredStrategy as the example. _distribute_datasets_from_function initializes the input workers, then sets up the contexts and reads the data.

def _distribute_datasets_from_function(self, dataset_fn, options):
  input_workers = self._input_workers_with_options(options)  # build InputWorkers
  input_contexts = []
  num_workers = input_workers.num_workers
  for i in range(num_workers):
    input_contexts.append(distribute_lib.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=self._num_replicas_in_sync))

  return input_lib.get_distributed_datasets_from_function(
      dataset_fn, input_workers, input_contexts, self._container_strategy(),
      options)

2.2.3 Creating InputWorkers

_input_workers_with_options creates the InputWorkers.

def _input_workers_with_options(self, options=None):
  if not options:
    return input_lib.InputWorkers(self._input_workers_devices)
  if (options.experimental_replication_mode ==
      distribute_lib.InputReplicationMode.PER_REPLICA):
    if options.experimental_place_dataset_on_device:
      self._input_workers_devices = (
          tuple(
              (device_util.canonicalize(d, d), (d,)) for d in self._devices))
    else:
      self._input_workers_devices = (
          tuple((device_util.canonicalize("/device:CPU:0", d), (d,))
                for d in self._devices))
    return input_lib.InputWorkers(self._input_workers_devices)
  else:
    if not options.experimental_fetch_to_device:
      return input_lib.InputWorkers([
          (host_device, (host_device,) * len(compute_devices))
          for host_device, compute_devices in self._input_workers_devices
      ])
    else:
      return input_lib.InputWorkers(self._input_workers_devices)
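
For intuition, here is an illustrative (assumed) value of self._input_workers_devices for a single-machine MirroredStrategy over two GPUs with the default options, i.e. experimental_fetch_to_device left enabled: the input pipeline lives on the host CPU, which feeds both GPU replicas. The device strings are examples only.

worker_device_pairs = (
    ("/job:localhost/replica:0/task:0/device:CPU:0",    # input worker (host)
     ("/job:localhost/replica:0/task:0/device:GPU:0",   # compute devices fed
      "/job:localhost/replica:0/task:0/device:GPU:1")), # by this input worker
)
# input_lib.InputWorkers(worker_device_pairs) then maps this single input
# worker onto its two compute devices.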

2.2.4 input_contexts

input_contexts is a list of InputContext objects. InputContext is a class that wraps the information needed by an input function: it is a context object passed to the user's input function and contains information about the compute replicas and the input pipelines.

  • The number of compute replicas (in sync training) lets us compute the local (per-replica) batch size from the desired global batch size.

  • The input pipeline information can be used to return a different subset of the input in each replica (for example, to shard the input pipeline, to use different input sources, etc.).

@tf_export("distribute.InputContext")
class InputContext(object):
  """A class wrapping information needed by an input function.

  This is a context class that is passed to the user's input function and
  contains information about the compute replicas and input pipelines. The
  number of compute replicas (in sync training) helps compute the local batch
  size from the desired global batch size for each replica. The input pipeline
  information can be used to return a different subset of the input in each
  replica (for e.g. shard the input pipeline, use a different input
  source etc).
  """

  __slots__ = [
      "_num_input_pipelines", "_input_pipeline_id", "_num_replicas_in_sync"
  ]

  def __init__(self,
               num_input_pipelines=1,
               input_pipeline_id=0,
               num_replicas_in_sync=1):
    """Initializes an InputContext object.

    Args:
      num_input_pipelines: the number of input pipelines in a cluster.
      input_pipeline_id: the current input pipeline id, should be an int in
        [0,num_input_pipelines).
      num_replicas_in_sync: the number of replicas that are in sync.
    """
    self._num_input_pipelines = num_input_pipelines
    self._input_pipeline_id = input_pipeline_id
    self._num_replicas_in_sync = num_replicas_in_sync

  @property
  def num_replicas_in_sync(self):
    """Returns the number of compute replicas in sync."""
    return self._num_replicas_in_sync

  @property
  def input_pipeline_id(self):
    """Returns the input pipeline ID."""
    return self._input_pipeline_id

  @property
  def num_input_pipelines(self):
    """Returns the number of input pipelines."""
    return self._num_input_pipelines

  def get_per_replica_batch_size(self, global_batch_size):
    """Returns the per-replica batch size.

    Args:
      global_batch_size: the global batch size which should be divisible by
        num_replicas_in_sync.

    Returns:
      the per-replica batch size.

    Raises:
      ValueError: if global_batch_size not divisible by
        num_replicas_in_sync.
    """
    if global_batch_size % self._num_replicas_in_sync != 0:
      raise ValueError("The global_batch_size %r is not divisible by "
                       "num_replicas_in_sync %r " %
                       (global_batch_size, self._num_replicas_in_sync))
    return global_batch_size // self._num_replicas_in_sync

  def __str__(self):
    return "tf.distribute.InputContext(input pipeline id {}, total: {})".format(
        self.input_pipeline_id, self.num_input_pipelines)
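
A tiny worked example of the batch-size arithmetic above, using the public tf.distribute.InputContext directly (the numbers are arbitrary): with 4 replicas in sync, a global batch of 64 becomes 16 per replica, while a non-divisible global batch raises ValueError.

import tensorflow as tf

ctx = tf.distribute.InputContext(
    num_input_pipelines=2, input_pipeline_id=0, num_replicas_in_sync=4)

assert ctx.get_per_replica_batch_size(64) == 16  # 64 // 4
try:
  ctx.get_per_replica_batch_size(66)             # 66 % 4 != 0
except ValueError as e:
  print(e)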

2.2.5 Returning the dataset

get_distributed_datasets_from_function returns a distributed dataset built from the given input function. It is a common helper used by all strategies. Depending on whether we are in a TF 1 or TF 2 context, a different distributed dataset instance is returned, and these instances differ in the APIs they support.

def get_distributed_datasets_from_function(dataset_fn,
                                           input_workers,
                                           input_contexts,
                                           strategy,
                                           options=None):
  """Returns a distributed dataset from the given input function.

  This is a common function that is used by all strategies to return a
  distributed dataset. The distributed dataset instance returned is different
  depending on if we are in a TF 1 or TF 2 context. The distributed dataset
  instances returned differ from each other in the APIs supported by each of
  them.

  Args:
    dataset_fn: a function that returns a tf.data.Dataset instance.
    input_workers: an InputWorkers object which specifies devices on which
        iterators should be created.
    input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
    strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
    options: Default is None. tf.distribute.InputOptions used to control
        options on how this dataset is distributed.

  Returns:
    A distributed dataset instance.

  Raises:
    ValueError: if options.experimental_replication_mode and
    options.experimental_place_dataset_on_device are not consistent
  """

  if tf2.enabled():
    return DistributedDatasetsFromFunction(input_workers, strategy,
                                           input_contexts, dataset_fn, options)
  else:
    return DistributedDatasetsFromFunctionV1(input_workers, strategy,
                                             input_contexts, dataset_fn,
                                             options)

2.2.6 Building the datasets

DistributedDatasetsFromFunction calls _create_datasets_from_function_with_input_context.

class DistributedDatasetsFromFunction(_IterableInput,
                                      composite_tensor.CompositeTensor):
  """Inputs created from dataset function."""

  def __init__(self,
               input_workers,
               strategy,
               input_contexts=None,
               dataset_fn=None,
               options=None,
               components=None,
               element_spec=None):
    """Makes an iterable from datasets created by the given function.

    Args:
      input_workers: an InputWorkers object.
      strategy: a tf.distribute.Strategy object, used to run all-reduce to
        handle last partial batch.
      input_contexts: A list of InputContext instances to be passed to call(s)
        to dataset_fn. Length and order should match worker order in
        worker_device_pairs.
      dataset_fn: A function that returns a Dataset given an InputContext.
        Either dataset_fn or components should be passed to construct
        DistributedDatasetsFromFunction. Use this when constructing
        DistributedDataset using a function. Use components when constructing
        using DistributedDatasetsFromFunctionSpec.
      options: tf.distribute.InputOptions used to control options on how this
        dataset is distributed.
      components: datasets when DistributedDatasetsFromFunction is constructed
        from DistributedDatasetsFromFunctionSpec. Only one of dataset or
        components should be passed.
      element_spec: element spec for DistributedDataset when constructing from
        DistributedDatasetSpec. This will be used to set the element_spec for
        DistributedDatasetsFromFunctionSpec and verified against element_spec
        from components.
    """
    super(DistributedDatasetsFromFunction, self).__init__(
        input_workers=input_workers)
    self._input_workers = input_workers
    self._strategy = strategy
    self._options = options

    if dataset_fn is not None:
      self._datasets, element_spec = (
          _create_datasets_from_function_with_input_context(
              input_contexts, self._input_workers, dataset_fn))
      self._element_spec = _create_distributed_tensor_spec(
          self._strategy, element_spec)
    else:
      self._element_spec = element_spec
      self._datasets = components

    self._enable_get_next_as_optional = _enable_get_next_as_optional(
        self._strategy, self._datasets[0])

The _create_datasets_from_function_with_input_context function actually builds the datasets.

def _create_datasets_from_function_with_input_context(input_contexts,
                                                      input_workers,
                                                      dataset_fn):
  """Create device datasets per worker given a dataset function."""
  datasets = []
  for i, ctx in enumerate(input_contexts):    # iterate over the input contexts
    worker = input_workers.worker_devices[i]  # the matching input worker device
    with ops.device(worker):
      dataset = dataset_fn(ctx)  # create this worker's dataset on that device
      datasets.append(dataset)
  return datasets, dataset.element_spec
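
The pattern above, creating each worker's dataset inside that worker's device scope so the input pipeline ops are placed where they will be consumed, can be imitated with public APIs as follows; the single device name and the toy dataset are assumptions for illustration only.

import tensorflow as tf

worker_devices = ["/job:localhost/replica:0/task:0/device:CPU:0"]  # assumed

def dataset_fn(ctx):
  del ctx  # the InputContext is unused in this toy sketch
  return tf.data.Dataset.range(8).batch(4)

datasets = []
for device in worker_devices:
  with tf.device(device):            # place the pipeline on this worker
    datasets.append(dataset_fn(None))

element_spec = datasets[-1].element_spec  # spec taken from the last dataset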

The overall logic is as follows:

Figure 7: Building data via a function

2.3 Higher-level usage

2.3.1 Keras

Let us first look at usage within Keras. tensorflow/python/keras/distribute/distributed_training_utils_v1.py has the following method, which produces the strategy's data iterator.

def get_iterator(dataset, distribution_strategy):
  with distribution_strategy.scope():
    iterator = distribution_strategy.make_dataset_iterator(dataset)
  initialize_iterator(iterator, distribution_strategy)
  return iterator

tensorflow/python/distribute/distribute_lib.py delegates to _extended; for example, StrategyBase has:

def make_dataset_iterator(self, dataset):
  """DEPRECATED TF 1.x ONLY."""
  return self._extended._make_dataset_iterator(dataset)  

For StrategyV1:

def make_dataset_iterator(self, dataset):
  """Makes an iterator for input provided via dataset.

  DEPRECATED: This method is not available in TF 2.x.

  Data from the given dataset will be distributed evenly across all the
  compute replicas. We will assume that the input dataset is batched by the
  global batch size. With this assumption, we will make a best effort to
  divide each batch across all the replicas (one or more workers).
  If this effort fails, an error will be thrown, and the user should instead
  use make_input_fn_iterator which provides more control to the user, and
  does not try to divide a batch across replicas.

  The user could also use make_input_fn_iterator if they want to
  customize which input is fed to which replica/worker etc.

  Args:
    dataset: tf.data.Dataset that will be distributed evenly across all
      replicas.

  Returns:
    An tf.distribute.InputIterator which returns inputs for each step of the
    computation.  User should call initialize on the returned iterator.
  """
  return self._extended._make_dataset_iterator(dataset) 

In tensorflow/python/distribute/mirrored_strategy.py, the following code creates a DatasetIterator:

def _make_dataset_iterator(self, dataset):
  return input_lib.DatasetIterator(
      dataset,
      self._input_workers,
      self._container_strategy(),
      num_replicas_in_sync=self._num_replicas_in_sync)

The overall logic is as follows:

Figure 9: Keras usage

2.3.2 Another execution path

Another execution path is as follows:

def single_loss_example(optimizer_fn, distribution, use_bias=False,
                        iterations_per_step=1):
  """Build a very simple network to use in tests and examples."""

  def dataset_fn():
    return dataset_ops.Dataset.from_tensors([[1.]]).repeat()

  optimizer = optimizer_fn()
  layer = core.Dense(1, use_bias=use_bias)

  def loss_fn(ctx, x):
    del ctx
    y = array_ops.reshape(layer(x), []) - constant_op.constant(1.)
    return y * y

  single_loss_step = step_fn.StandardSingleLossStep(
      dataset_fn, loss_fn, optimizer, distribution, iterations_per_step)
  return single_loss_step, layer

StandardSingleLossStep is defined as follows:

class StandardSingleLossStep(StandardInputStep):
  """A step function that implements a training step for a feed forward network.

  An instance of this class is intended to be used as a callable:

  ```python
  ...
  step = step_fn.StandardSingleLossStep(
      dataset, loss_fn, optimizer, distribution)

  # Run a single training step on a given DistributionStrategy:
  step(distribution)
  ...
  ```

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
    loss_fn: a function that takes a context and inputs as arguments. It returns
      the loss for those inputs. context is an instance of
      values.MultiStepContext that will be passed when loss_fn is run.
      context can be used to specify the outputs to be returned from
      loss_fn, among other things.
    optimizer: an optimizer that implements an update rule.
    distribution: a DistributionStrategy object.
  """

  def __init__(self, dataset_fn, loss_fn, optimizer, distribution,
               iterations_per_step=1):
    super(StandardSingleLossStep, self).__init__(dataset_fn, distribution)
    self._loss_fn = loss_fn
    self._optimizer = optimizer
    self._iterations_per_step = iterations_per_step

  def __call__(self):
    with self._distribution.scope():
      def step_fn(ctx, inputs):
        """Function to run one iteration with one input."""
        gradients_fn = backprop.implicit_grad(self._loss_fn)
        gradients_fn = optimizer_lib.get_filtered_grad_fn(gradients_fn)

        grads_and_vars = self.distribution.extended.call_for_each_replica(
            gradients_fn, args=(ctx, inputs))
        # If threads use layers, then we need to run the first step
        # sequentially, so that layers.build() is not executed in parallel.
        # Otherwise, multiple sets of mirrored variables are going to be
        # created.
        return self._optimizer._distributed_apply(  # pylint: disable=protected-access
            self.distribution, grads_and_vars)

      ctx = self.distribution.extended.experimental_run_steps_on_iterator(
          step_fn, self._iterator, self._iterations_per_step)
      return ctx.run_op

StandardInputStep is where the _iterator is created.

class StandardInputStep(Step):
  """Step with a standard implementation of input handling.

  Args:
    dataset_fn: a function that returns a tf.data Dataset that produces the
      input for the model.
  """

  def __init__(self, dataset_fn, distribution):
    super(StandardInputStep, self).__init__(distribution)
    self._iterator = distribution.make_input_fn_iterator(lambda _: dataset_fn())

  def initialize(self):
    return self._iterator.initializer

StrategyV1 has:

def make_input_fn_iterator(self,  # pylint: disable=useless-super-delegation
                           input_fn,
                           replication_mode=InputReplicationMode.PER_WORKER):
  """Returns an iterator split across replicas created from an input function.

  DEPRECATED: This method is not available in TF 2.x.

  The input_fn should take an tf.distribute.InputContext object where
  information about batching and input sharding can be accessed:

  ```
  def input_fn(input_context):
    batch_size = input_context.get_per_replica_batch_size(global_batch_size)
    d = tf.data.Dataset.from_tensors([[1.]]).repeat().batch(batch_size)
    return d.shard(input_context.num_input_pipelines,
                   input_context.input_pipeline_id)
  with strategy.scope():
    iterator = strategy.make_input_fn_iterator(input_fn)
    replica_results = strategy.experimental_run(replica_fn, iterator)
  ```

  The tf.data.Dataset returned by input_fn should have a per-replica
  batch size, which may be computed using
  input_context.get_per_replica_batch_size.

  Args:
    input_fn: A function taking a tf.distribute.InputContext object and
      returning a tf.data.Dataset.
    replication_mode: an enum value of tf.distribute.InputReplicationMode.
      Only PER_WORKER is supported currently, which means there will be
      a single call to input_fn per worker. Replicas will dequeue from the
      local tf.data.Dataset on their worker.

  Returns:
    An iterator object that should first be .initialize()-ed. It may then
    either be passed to strategy.experimental_run() or you can
    iterator.get_next() to get the next value to pass to
    strategy.extended.call_for_each_replica().
  """
  return super(StrategyV1, self).make_input_fn_iterator(
      input_fn, replication_mode)

StrategyBase has:

@doc_controls.do_not_generate_docs  # DEPRECATED: TF 1.x only
def make_input_fn_iterator(self,
                           input_fn,
                           replication_mode=InputReplicationMode.PER_WORKER):
  """DEPRECATED TF 1.x ONLY."""
  if replication_mode != InputReplicationMode.PER_WORKER:
    raise ValueError(
        "Input replication mode not supported: %r" % replication_mode)
  with self.scope():
    return self.extended._make_input_fn_iterator(  # pylint: disable=protected-access
        input_fn, replication_mode=replication_mode)

Finally we reach MirroredStrategy, which creates an InputFunctionIterator; its class hierarchy is:

class InputFunctionIterator(DistributedIteratorV1)
class DistributedIteratorV1(DistributedIteratorBase)

The code is as follows:

def _make_input_fn_iterator(
    self,
    input_fn,
    replication_mode=distribute_lib.InputReplicationMode.PER_WORKER):
  input_contexts = []
  num_workers = self._input_workers.num_workers
  for i in range(num_workers):
    input_contexts.append(distribute_lib.InputContext(
        num_input_pipelines=num_workers,
        input_pipeline_id=i,
        num_replicas_in_sync=self._num_replicas_in_sync))
  return input_lib.InputFunctionIterator(input_fn, self._input_workers,
                                         input_contexts,
                                         self._container_strategy())

The logic is as follows:

Figure 10: Usage example
