[Source Code Analysis] TensorFlow Distributed Environment (6) --- Master Dynamic Logic

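Before looking at TensorFlow's various distributed Strategies in detail, we first need to understand the foundation of distribution: the distributed environment. Only with a solid foundation can the later analysis proceed smoothly and efficiently. Starting from the Client, this article examines how the Master processes the computation graph.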

As before, this article draws heavily on the work of two authors:

Other articles in this series:

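[Translation] TensorFlow Distributed Papers: "TensorFlow : Large-Scale Machine Learning on Heterogeneous Distributed Systems"

[Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

[Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

[Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

[Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

[Source Code Analysis] TensorFlow Distributed Environment (5) --- Session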

1. GrpcSession

1.1 Run

First, the client calls GrpcSession to start execution, and its Run method calls RunHelper.

Status GrpcSession::Run(const RunOptions& run_options,
                        const std::vector<std::pair<string, Tensor>>& inputs,
                        const std::vector<string>& output_tensor_names,
                        const std::vector<string>& target_node_names,
                        std::vector<Tensor>* outputs,
                        RunMetadata* run_metadata) {
  return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                   outputs, run_metadata, /* prun_handle */ "");
}

RunHelper is shown below. The important steps are adding the feeds and fetches to the request and then calling RunProto to run the session.

Status GrpcSession::RunHelper(
    const RunOptions& run_options,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& output_tensor_names,
    const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
    RunMetadata* run_metadata, const string& prun_handle) {
  // Convert to proto
  std::unique_ptr<MutableRunStepRequestWrapper> req(
      master_->CreateRunStepRequest());
  std::unique_ptr<MutableRunStepResponseWrapper> resp(
      master_->CreateRunStepResponse());

  *req->mutable_options() = run_options;

  if (run_options.timeout_in_ms() == 0) {
    req->mutable_options()->set_timeout_in_ms(
        options_.config.operation_timeout_in_ms());
  }

  if (!prun_handle.empty()) {
    req->set_partial_run_handle(prun_handle);
  }

  for (const auto& it : inputs) {
    req->add_feed(it.first, it.second);
  }

  // Support long error messages by storing the error code in the response body.
  req->set_store_errors_in_response_body(true);

  // Build an index from fetch tensor name to first index in
  // output_tensor_names.
  std::unordered_map<string, int> output_name_to_offset;
  for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
    const string& name = output_tensor_names[i];
    if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
      req->add_fetch(name);
    }
  }
  for (const string& target : target_node_names) {
    req->add_target(target);
  }

  CallOptions call_options;
  call_options.SetTimeout(req->options().timeout_in_ms());
  
  // Call RunProto to run the session
  TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));

  // Look for an extended error returned in the response body.
  if (resp->status_code() != error::Code::OK) {
    return resp->status();
  }

  if (!output_tensor_names.empty()) {
    outputs->resize(output_tensor_names.size());
  }

  // Convert response back to Tensors in the correct order.
  for (size_t i = 0; i < resp->num_tensors(); ++i) {
    auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
    if (fetch_it == output_name_to_offset.end()) {
      return errors::Internal("Received response for unrequested fetch: ",
                              resp->tensor_name(i));
    }

    Tensor output;
    TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
    (*outputs)[fetch_it->second] = output;
  }
  // In the unlikely event that output_tensor_names contains duplicates, fill in
  // the duplicate values.
  if (output_name_to_offset.size() != output_tensor_names.size()) {
    for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
      const string& name = output_tensor_names[i];
      int offset = output_name_to_offset[name];
      if (offset != i) {
        (*outputs)[i] = (*outputs)[offset];
      }
    }
  }

  if (run_metadata) {
    run_metadata->Swap(resp->mutable_metadata());
  }

  return Status::OK();
}
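
The de-duplication of fetch names in RunHelper is easy to miss: each distinct name is requested only once, and duplicate output slots are filled from the first occurrence after the response arrives. The sketch below isolates that logic with plain standard-library types; `Value`, `DedupFetches` and `FillOutputs` are hypothetical names, and a tensor is reduced to an `int`.

```cpp
#include <cassert>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Hypothetical stand-in for a fetched tensor value.
using Value = int;

// Request side: remember the index of the first occurrence of each fetch
// name and return the de-duplicated list of names to put in the request.
std::vector<std::string> DedupFetches(
    const std::vector<std::string>& output_names,
    std::unordered_map<std::string, int>* name_to_offset) {
  std::vector<std::string> unique_fetches;
  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
    // insert() only succeeds for the first occurrence of a name.
    if (name_to_offset->insert({output_names[i], i}).second) {
      unique_fetches.push_back(output_names[i]);
    }
  }
  return unique_fetches;
}

// Response side: put each returned value at the offset of its first
// occurrence, then copy it into duplicate slots. Precondition: the map was
// built by DedupFetches over the same output_names.
bool FillOutputs(const std::vector<std::string>& output_names,
                 const std::unordered_map<std::string, int>& name_to_offset,
                 const std::vector<std::pair<std::string, Value>>& response,
                 std::vector<Value>* outputs) {
  outputs->assign(output_names.size(), Value{});
  for (const auto& [name, value] : response) {
    auto it = name_to_offset.find(name);
    if (it == name_to_offset.end()) return false;  // unrequested fetch
    (*outputs)[it->second] = value;
  }
  for (int i = 0; i < static_cast<int>(output_names.size()); ++i) {
    const int offset = name_to_offset.at(output_names[i]);
    if (offset != i) (*outputs)[i] = (*outputs)[offset];
  }
  return true;
}
```

For fetches `{"a:0", "b:0", "a:0"}` only two names go into the request, and the third output slot is copied from the first.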

In the end, RunProto calls master_->RunStep to do the actual work.

Status GrpcSession::RunProto(CallOptions* call_options,
                             MutableRunStepRequestWrapper* req,
                             MutableRunStepResponseWrapper* resp) {
  string handle;
  TF_RETURN_IF_ERROR(Handle(&handle));
  req->set_session_handle(handle);
  return master_->RunStep(call_options, req, resp);
}

master_ is a GrpcRemoteMaster, so that is where we continue.

1.2 GrpcRemoteMaster

GrpcRemoteMaster is the gRPC client implementation on the Client side. Its RunStep method simply uses the gRPC stub to call the RunStep interface of the remote MasterService, which amounts to sending a RunStepRequest.

Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
               MutableRunStepResponseWrapper* response) override {
  return CallWithRetry(call_options, &request->ToProto(),
                       get_proto_from_wrapper(response),
                       &MasterServiceStub::RunStep, "RunStep/Client");
}
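
CallWithRetry wraps the stub call in a retry loop. Its body is not shown in this article, so the following is only a sketch under stated assumptions: transient failures are modeled by a hypothetical `Code::kUnavailable`, any other result is returned immediately, and the number of retries is capped. The waiting between attempts that a real client performs is left out.

```cpp
#include <cassert>
#include <functional>

// Simplified status codes; hypothetical, not TensorFlow's error::Code.
enum class Code { kOk, kUnavailable, kInvalidArgument };

// Retry only transient ("unavailable") failures, at most max_retries times
// after the first attempt. *attempts reports how many calls were made.
Code CallWithRetrySketch(const std::function<Code()>& rpc, int max_retries,
                         int* attempts) {
  *attempts = 0;
  for (;;) {
    ++*attempts;
    const Code c = rpc();
    if (c != Code::kUnavailable) return c;  // success or permanent error
    if (*attempts > max_retries) return c;  // retry budget exhausted
    // A real client would back off (sleep) here before the next attempt.
  }
}
```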

At this point, the Client-side logic extends as follows:

Figure 1: Master dynamic logic 1

2. Master

From here on we are on the server that plays the Master role. GrpcMasterService runs the gRPC service; when it receives a RunStepRequest, the system calls RunStepHandler. The code is in tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc.

// RPC handler for running one step in a session.
void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
  auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
  CallOptions* call_opts = new CallOptions;
  if (call->request.options().timeout_in_ms() > 0) {
    call_opts->SetTimeout(call->request.options().timeout_in_ms());
  } else {
    call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
  }
  RunStepRequestWrapper* wrapped_request =
      new ProtoRunStepRequest(&call->request);
  MutableRunStepResponseWrapper* wrapped_response =
      new NonOwnedProtoRunStepResponse(&call->response);
  call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
  master_impl_->RunStep(
      call_opts, wrapped_request, wrapped_response,
      [call, call_opts, wrapped_request, trace](const Status& status) {
        call->ClearCancelCallback();
        delete call_opts;
        delete wrapped_request;
        delete trace;
        if (call->request.store_errors_in_response_body() && !status.ok()) {
          call->response.set_status_code(status.code());
          call->response.set_status_error_message(status.error_message());
          call->SendResponse(ToGrpcStatus(Status::OK()));
        } else {
          call->SendResponse(ToGrpcStatus(status));
        }
      });
  ENQUEUE_REQUEST(RunStep, true);
}
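
The completion callback in RunStepHandler decides where an error travels. Because GrpcSession sets store_errors_in_response_body(true), a failed step is normally reported inside the RunStepResponse while the gRPC call itself succeeds, which is why RunHelper checks resp->status_code(). The sketch below reproduces that branch with hypothetical simplified types (`SimpleStatus`, `SimpleResponse`) in place of TensorFlow's Status and response wrapper.

```cpp
#include <cassert>
#include <string>

// Hypothetical simplified status: code 0 means OK.
struct SimpleStatus {
  int code = 0;
  std::string message;
  bool ok() const { return code == 0; }
};

// Hypothetical simplified response body carrying an optional error.
struct SimpleResponse {
  int status_code = 0;
  std::string status_error_message;
};

// Returns the transport-level status. If the client asked for errors in the
// body, a failed step is copied into the response and the transport status
// is OK; otherwise the step status is returned as the transport status.
SimpleStatus FinishRunStep(const SimpleStatus& step_status,
                           bool store_errors_in_response_body,
                           SimpleResponse* response) {
  if (store_errors_in_response_body && !step_status.ok()) {
    response->status_code = step_status.code;
    response->status_error_message = step_status.message;
    return SimpleStatus{};  // OK at the transport level
  }
  return step_status;
}
```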

master_impl_ is the Master instance; its RunStep calls the MasterSession to perform the computation.

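void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                     MutableRunStepResponseWrapper* resp, MyClosure done) {
  // Look up the session; FindMasterSession takes a reference on it.
  auto session = FindMasterSession(req->session_handle());
  if (session == nullptr) {
    done(errors::Aborted("Session ", req->session_handle(), " is not found."));
    return;
  }

  // Run the session in a scheduled closure (step-time statistics elided),
  // then drop the reference and report the result.
  SchedClosure([session, opts, req, resp, done]() {
    Status status = session->Run(opts, *req, resp);
    session->Unref();
    done(status);
  });
}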

We have now reached the Master's business logic proper; next we look at how the request is processed further.

2.1 Overview

Let us start with an overview. On the Master:

  • First, the FullGraph is pruned to produce the ClientGraph.
  • Then the ClientGraph is split by Worker into multiple PartitionGraphs.
  • Finally, the PartitionGraphs are registered with their Workers (one RPC), and the Workers are started to execute them concurrently (another RPC).
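
The pruning in the first step can be pictured as a reachability query: starting from the fetches and targets, keep every node they transitively depend on and drop the rest. A minimal sketch under that assumption follows; the adjacency map and `PruneSketch` are hypothetical, and TensorFlow's rewriting of feeds and fetches into special nodes before pruning is not modeled.

```cpp
#include <cassert>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

// Hypothetical graph form: node name -> names of the nodes it consumes.
using InputMap = std::map<std::string, std::vector<std::string>>;

// Keep every node that some root (fetch or target) transitively depends on,
// using a breadth-first walk over input edges.
std::set<std::string> PruneSketch(const InputMap& inputs,
                                  const std::vector<std::string>& roots) {
  std::set<std::string> kept;
  std::queue<std::string> frontier;
  for (const auto& r : roots) {
    if (kept.insert(r).second) frontier.push(r);
  }
  while (!frontier.empty()) {
    const std::string n = frontier.front();
    frontier.pop();
    auto it = inputs.find(n);
    if (it == inputs.end()) continue;  // a source node with no inputs
    for (const auto& in : it->second) {
      if (kept.insert(in).second) frontier.push(in);
    }
  }
  return kept;
}
```

With fetch `loss`, a `summary` node that consumes `loss` and an `unused` node that shares an input with kept nodes are both dropped.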

Now to the code. First, the Master calls FindMasterSession to find the MasterSession corresponding to session_handle; from then on, the MasterSession takes over.

MasterSession* Master::FindMasterSession(const string& handle) {
  MasterSession* session = nullptr;
  {
    mutex_lock l(mu_);
    session = gtl::FindPtrOrNull(sessions_, handle);
    if (session != nullptr) {
      session->Ref();
    }
  }
  return session;
}
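
FindMasterSession takes the reference while it still holds the lock, so a session cannot be released between being found and being pinned by the caller; the caller is then responsible for calling Unref once it is done. The sketch below shows that pattern with standard primitives; `RefCountedSession` and `SessionRegistry` are hypothetical stand-ins, not TensorFlow classes.

```cpp
#include <atomic>
#include <cassert>
#include <map>
#include <mutex>
#include <string>

// Hypothetical minimal ref-counted session; starts with the registry's ref.
class RefCountedSession {
 public:
  void Ref() { refs_.fetch_add(1); }
  // Returns true when this call dropped the last reference.
  bool Unref() { return refs_.fetch_sub(1) == 1; }
  int refs() const { return refs_.load(); }

 private:
  std::atomic<int> refs_{1};
};

// Lookup and Ref() happen under the same lock, as in FindMasterSession.
class SessionRegistry {
 public:
  void Add(const std::string& handle, RefCountedSession* s) {
    std::lock_guard<std::mutex> l(mu_);
    sessions_[handle] = s;
  }
  RefCountedSession* Find(const std::string& handle) {
    std::lock_guard<std::mutex> l(mu_);
    auto it = sessions_.find(handle);
    if (it == sessions_.end()) return nullptr;
    it->second->Ref();  // pin before releasing the lock
    return it->second;
  }

 private:
  std::mutex mu_;
  std::map<std::string, RefCountedSession*> sessions_;
};
```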

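Next, MasterSession::Run can take one of two paths: DoPartialRun when the request carries a partial_run_handle, and DoRunWithLocalExecution otherwise. Here we analyze DoRunWithLocalExecution.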

Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                          MutableRunStepResponseWrapper* resp) {
  UpdateLastAccessTime();
  {
    mutex_lock l(mu_);
    if (closed_) {
      return errors::FailedPrecondition("Session is closed.");
    }
    ++num_running_;
    // Note: all code paths must eventually call MarkRunCompletion()
    // in order to appropriate decrement the num_running_ counter.
  }
  Status status;
  if (!req.partial_run_handle().empty()) {
    status = DoPartialRun(opts, req, resp);
  } else {
    status = DoRunWithLocalExecution(opts, req, resp);
  }
  return status;
}

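DoRunWithLocalExecution performs three main operations:

  • StartStep calls BuildGraph to generate the ClientGraph; this is where pruning happens.
  • BuildAndRegisterPartitions splits the computation graph into multiple subgraphs by location and registers them.
  • RunPartitions executes the subgraphs. Each subgraph corresponds to one worker, that is, to one worker service.
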
Status MasterSession::DoRunWithLocalExecution(
    CallOptions* opts, const RunStepRequestWrapper& req,
    MutableRunStepResponseWrapper* resp) {

  PerStepState pss;
  pss.start_micros = Env::Default()->NowMicros();
  auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });

  // Prepare.
  BuildGraphOptions bgopts;
  BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
  ReffedClientGraph* rcg = nullptr;
  int64 count;
  // StartStep calls BuildGraph to generate the ClientGraph; pruning happens here
  TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));

  // Unref "rcg" when out of scope.
  core::ScopedUnref unref(rcg);

  // Partition the computation graph and register the partitions
  TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));

  // Keeps the highest 8 bits 0x01: we reserve some bits of the
  // step_id for future use.
  uint64 step_id = NewStepId(rcg->collective_graph_key());

  std::unique_ptr<ProfileHandler> ph;
  FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);

  if (pss.collect_partition_graphs &&
      session_opts_.config.experimental().disable_output_partition_graphs()) {
    return errors::InvalidArgument(
        "RunOptions.output_partition_graphs() is not supported when "
        "disable_output_partition_graphs is true.");
  }

  // Execute the computation graph
  Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                &cancellation_manager_, false);

  cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
  return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
                        resp->mutable_metadata());
}
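
DoRunWithLocalExecution guarantees that MarkRunCompletion runs on every path: early error returns trigger it through the gtl::MakeCleanup guard, while the normal path calls release() and leaves it to PostRunCleanup. A minimal scope guard with the same contract is sketched below; `Cleanup` and `RunStepLike` are hypothetical names, not the TensorFlow types.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Runs the callback on scope exit unless release() was called first.
class Cleanup {
 public:
  explicit Cleanup(std::function<void()> f) : f_(std::move(f)) {}
  Cleanup(const Cleanup&) = delete;
  Cleanup& operator=(const Cleanup&) = delete;
  ~Cleanup() {
    if (f_) f_();
  }
  // Hand the responsibility for the callback to someone else.
  void release() { f_ = nullptr; }

 private:
  std::function<void()> f_;
};

// Mimics the control flow: an early return fires the guard; the normal path
// releases it and performs the completion step explicitly.
int RunStepLike(bool fail_early, int* completions) {
  Cleanup cleanup([completions] { ++*completions; });
  if (fail_early) return -1;  // guard fires when the function returns
  cleanup.release();
  ++*completions;  // stands in for PostRunCleanup -> MarkRunCompletion
  return 0;
}
```

Either way the completion counter goes up exactly once per call, which is the invariant the num_running_ counter relies on.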

Next we analyze the three main operations of DoRunWithLocalExecution one by one.

2.2 Building & Pruning

2.2.1 Building the Computation Graph

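The key job of StartStep is to build the computation graph and prune it. The result is cached under a hash of the BuildGraphOptions, so a graph is only built the first time a given combination of options is seen.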

Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
                                ReffedClientGraph** out_rcg,
                                int64_t* out_count) {
  const uint64 hash = HashBuildGraphOptions(opts);
  {
    mutex_lock l(mu_);
    RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
    auto iter = m->find(hash);
    if (iter == m->end()) {
      // We have not seen this subgraph before. Build the subgraph and
      // cache it.
      std::unique_ptr<ClientGraph> client_graph;
      // Build the computation graph
      TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
      WorkerCacheInterface* worker_cache = get_worker_cache();
      auto entry = new ReffedClientGraph(
          handle_, opts, std::move(client_graph), session_opts_,
          stats_publisher_factory_, is_partial, worker_cache,
          !should_delete_worker_sessions_);
      iter = m->insert({hash, entry}).first;
    }
    *out_rcg = iter->second;
    (*out_rcg)->Ref();
    *out_count = (*out_rcg)->get_and_increment_execution_count();
  }
  return Status::OK();
}
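
StartStep builds a ClientGraph only the first time a given set of build options is seen; later steps with the same feeds, fetches and targets reuse the cached ReffedClientGraph and only bump its execution count. The sketch below mimics that cache. Hashing a string key with std::hash and the `GraphCache`/`CachedGraph` names are assumptions, and the counter is assumed to return its value before incrementing, as the name get_and_increment_execution_count suggests.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <memory>
#include <mutex>
#include <string>

// Hypothetical cached entry standing in for ReffedClientGraph.
struct CachedGraph {
  std::string description;
  long long execution_count = 0;
};

class GraphCache {
 public:
  // Build on the first miss for this options key; always return the
  // execution count as it was before this call.
  CachedGraph* GetOrBuild(const std::string& options_key,
                          const std::function<std::string()>& build,
                          long long* out_count) {
    const size_t hash = std::hash<std::string>{}(options_key);
    std::lock_guard<std::mutex> l(mu_);
    auto it = graphs_.find(hash);
    if (it == graphs_.end()) {
      ++builds_;
      it = graphs_.emplace(hash, std::make_unique<CachedGraph>()).first;
      it->second->description = build();
    }
    *out_count = it->second->execution_count++;
    return it->second.get();
  }
  int builds() const {
    std::lock_guard<std::mutex> l(mu_);
    return builds_;
  }

 private:
  mutable std::mutex mu_;
  std::map<size_t, std::unique_ptr<CachedGraph>> graphs_;
  int builds_ = 0;
};
```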

2.2.2 Pruning

The most important step in BuildGraph is calling PruneGraph to do the pruning.

Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                       std::unique_ptr<ClientGraph>* out) {
  const uint64 start_time_usecs = Env::Default()->NowMicros();
  // Grappler optimization might change the structure of a graph itself, and
  // also it can add/prune functions to/from the library.
  std::unique_ptr<Graph> optimized_graph;
  std::unique_ptr<FunctionLibraryDefinition> optimized_flib;

  Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph,
                           &optimized_flib);
  if (!s.ok()) {
    // Simply copy the original graph and the function library if we couldn't
    // optimize it.
    optimized_graph.reset(new Graph(flib_def_.get()));
    CopyGraph(*graph_, optimized_graph.get());
    optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
  }

  subgraph::RewriteGraphMetadata rewrite_metadata;
  if (session_options_ == nullptr ||
      !session_options_->config.graph_options().place_pruned_graph()) {
    TF_RETURN_IF_ERROR(  // PruneGraph does the pruning
        PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
  } else {
    // This GraphExecutionState represents a graph that was
    // pruned when this was constructed, so we copy the metadata from
    // a member variable.
    CHECK(rewrite_metadata_);
    rewrite_metadata = *rewrite_metadata_;
  }

  GraphOptimizationPassOptions optimization_options;
  optimization_options.session_options = session_options_;
  optimization_options.graph = &optimized_graph;
  optimization_options.flib_def = optimized_flib.get();
  optimization_options.device_set = device_set_;

  TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
      OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));

  int64_t collective_graph_key = options.collective_graph_key;
  if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
    // BuildGraphOptions does not specify a collective_graph_key.  Check all
    // nodes in the Graph and FunctionLibraryDefinition for collective ops and
    // if found, initialize a collective_graph_key as a hash of the ordered set
    // of instance keys.
    std::set<int32> instance_key_set;
    bool has_collective_v2 = false;
    for (Node* node : optimized_graph->nodes()) {
      if (node->IsCollective()) {
        int32_t instance_key;
        TF_RETURN_IF_ERROR(
            GetNodeAttr(node->attrs(), "instance_key", &instance_key));
        instance_key_set.emplace(instance_key);
      } else if (IsCollectiveV2(node->type_string())) {
        has_collective_v2 = true;
      } else {
        const FunctionDef* fdef = optimized_flib->Find(node->def().op());
        if (fdef != nullptr) {
          for (const NodeDef& ndef : fdef->node_def()) {
            if (ndef.op() == "CollectiveReduce" ||
                ndef.op() == "CollectiveBcastSend" ||
                ndef.op() == "CollectiveBcastRecv" ||
                ndef.op() == "CollectiveGather") {
              int32_t instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(ndef, "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            } else if (IsCollectiveV2(ndef.op())) {
              has_collective_v2 = true;
            }
          }
        }
      }
    }
    if (!instance_key_set.empty()) {
      uint64 hash = 0x8774aa605c729c72ULL;
      for (int32_t instance_key : instance_key_set) {
        hash = Hash64Combine(instance_key, hash);
      }
      collective_graph_key = hash;
    } else if (has_collective_v2) {
      collective_graph_key = 0x8774aa605c729c72ULL;
    }
  }

  // Make collective execution order deterministic if needed.
  if (options.collective_order != GraphCollectiveOrder::kNone) {
    TF_RETURN_IF_ERROR(
        OrderCollectives(optimized_graph.get(), options.collective_order));
  }

  // Copy the extracted graph in order to make its node ids dense,
  // since the local CostModel used to record its stats is sized by
  // the largest node id.
  std::unique_ptr<ClientGraph> dense_copy(
      new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                      rewrite_metadata.fetch_types, collective_graph_key));
  CopyGraph(*optimized_graph, &dense_copy->graph);

  metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
  *out = std::move(dense_copy);
  return Status::OK();
}
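
When the caller does not supply a collective_graph_key, BuildGraph derives one from the instance keys of the collective ops it finds. Because the keys pass through a std::set, the result depends only on which keys are present, not on node order or repetition. The sketch below reproduces that folding. The seed is the constant from the code above; `HashCombineSketch` is a stand-in mixing function (an assumption, not necessarily identical to TensorFlow's Hash64Combine), and `kNoCollectiveGraphKey = 0` is a hypothetical sentinel.

```cpp
#include <cassert>
#include <cstdint>
#include <set>
#include <vector>

// Stand-in mixing function (assumption, for illustration only).
uint64_t HashCombineSketch(uint64_t a, uint64_t b) {
  return a ^ (b + 0x9e3779b97f4a7c15ULL + (a << 6) + (a >> 2));
}

constexpr uint64_t kNoCollectiveGraphKey = 0;      // hypothetical sentinel
constexpr uint64_t kSeed = 0x8774aa605c729c72ULL;  // seed used in BuildGraph

// Fold the ordered, de-duplicated instance keys into one key starting from
// the seed; a graph with only CollectiveV2 ops gets the bare seed.
uint64_t CollectiveGraphKey(const std::vector<int32_t>& instance_keys,
                            bool has_collective_v2) {
  const std::set<int32_t> ordered(instance_keys.begin(), instance_keys.end());
  if (!ordered.empty()) {
    uint64_t hash = kSeed;
    for (int32_t k : ordered) {
      hash = HashCombineSketch(static_cast<uint64_t>(k), hash);
    }
    return hash;
  }
  return has_collective_v2 ? kSeed : kNoCollectiveGraphKey;
}
```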

2.3 Partitioning & Registration

2.3.1 Principle

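Because a single device does not have enough compute or memory, large models must be sharded: the model and its computation are split and assigned to different devices.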

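TensorFlow's Placement mechanism addresses model sharding: it specifies which device each operation is placed on. The term, or rather the mechanism, of placement was probably first introduced by Google Spanner, where it manages data migration across zones and also serves a load-balancing purpose. TF's Placement borrows this idea from Google, with the following principles: satisfy user requirements where possible; prefer faster devices; prefer locality to avoid copies; and make sure the program can still run after assignment.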

Once Placement has run, every node carries placement information, and the Partition method can use this information to split the computation graph.

2.3.2 Configuration

BuildAndRegisterPartitions calls RegisterPartitions to partition and register the graph. We first look at how the partitioning is configured: SplitByWorker is used as the partitioning criterion.

Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
  // Configure the partitioning
  PartitionOptions popts;
  popts.node_to_loc = SplitByWorker;  // partition by worker
  popts.new_name = [this](const string& prefix) {
    mutex_lock l(mu_);
    return strings::StrCat(prefix, "_S", next_node_id_++);
  };
  popts.get_incarnation = [this](const string& name) -> int64 {
    Device* d = devices_->FindDeviceByName(name);
    if (d == nullptr) {
      return PartitionOptions::kIllegalIncarnation;
    } else {
      return d->attributes().incarnation();
    }
  };
  popts.control_flow_added = false;  // control flow has not been added yet
  const bool enable_bfloat16_sendrecv =
      session_opts_.config.graph_options().enable_bfloat16_sendrecv();
  // Which dtype a send/recv edge should be cast to, if any
  popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
    if (e->IsControlEdge()) {
      return DT_FLOAT;
    }
    DataType dtype = BaseType(e->src()->output_type(e->src_output()));
    if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
      return DT_BFLOAT16;
    } else {
      return dtype;
    }
  };
  if (session_opts_.config.graph_options().enable_recv_scheduling()) {
    popts.scheduling_for_recvs = true;
    popts.need_to_record_start_times = true;
  }

  // Partition the graph and register the subgraphs
  TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));

  return Status::OK();
}
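
Two of the callbacks configured above are small enough to isolate: should_cast picks the dtype that travels over a send/recv edge, and new_name produces unique names for the nodes Partition inserts. Below is a sketch with a hypothetical three-value `DType` enum and a hypothetical `NameGenerator`; the starting counter value and the absence of locking are simplifications.

```cpp
#include <cassert>
#include <string>

// Hypothetical reduced dtype enum.
enum class DType { kFloat, kBFloat16, kInt32 };

// Control edges carry a dummy float; float data edges are narrowed to
// bfloat16 only when enable_bfloat16_sendrecv is on.
DType ShouldCast(bool is_control_edge, DType edge_dtype,
                 bool enable_bfloat16_sendrecv) {
  if (is_control_edge) return DType::kFloat;
  if (enable_bfloat16_sendrecv && edge_dtype == DType::kFloat) {
    return DType::kBFloat16;
  }
  return edge_dtype;
}

// Generates "<prefix>_S<id>" names from a per-session counter.
class NameGenerator {
 public:
  std::string NewName(const std::string& prefix) {
    return prefix + "_S" + std::to_string(next_node_id_++);
  }

 private:
  long long next_node_id_ = 0;
};
```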

The SplitByWorker method is as follows.

static string SplitByWorker(const Node* node) {
  string task;
  string device;
  CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
                                         &device))
      << "node: " << node->name() << " dev: " << node->assigned_device_name();
  return task;
}
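
SplitByWorker keeps only the task part of a node's assigned device, so every device of one task (one worker process) ends up in the same partition. The following sketch uses a simplified, hypothetical parser rather than DeviceNameUtils::SplitDeviceName, assuming a full name of the form /job:<job>/replica:<r>/task:<t>/device:<type>:<id>.

```cpp
#include <cassert>
#include <string>

// Split at the last '/' into task and device parts, dropping an optional
// "device:" prefix from the device part.
bool SplitDeviceNameSketch(const std::string& full, std::string* task,
                           std::string* device) {
  const size_t slash = full.rfind('/');
  if (slash == std::string::npos || slash == 0) return false;
  *task = full.substr(0, slash);
  std::string dev = full.substr(slash + 1);
  const std::string kPrefix = "device:";
  if (dev.compare(0, kPrefix.size(), kPrefix) == 0) {
    dev = dev.substr(kPrefix.size());
  }
  *device = dev;
  return !dev.empty();
}

// Like SplitByWorker: the partition key is the task part only.
std::string SplitByWorkerSketch(const std::string& assigned_device) {
  std::string task, device;
  const bool ok = SplitDeviceNameSketch(assigned_device, &task, &device);
  assert(ok);  // SplitByWorker CHECK-fails on a malformed name
  (void)ok;
  return task;
}
```

Both `.../task:1/device:CPU:0` and `.../task:1/device:GPU:1` map to the same key `/job:worker/replica:0/task:1`.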

BuildAndRegisterPartitions then calls RegisterPartitions, which calls DoBuildPartitions to partition the graph and DoRegisterPartitions to register the partitions.

Status MasterSession::ReffedClientGraph::RegisterPartitions(
    PartitionOptions popts) {
  {  // Ensure register once.
    mu_.lock();
    if (client_graph_before_register_) {
      // The `ClientGraph` is no longer needed after partitions are registered.
      // Since it can account for a large amount of memory, we consume it here,
      // and it will be freed after concluding with registration.

      std::unique_ptr<ClientGraph> client_graph;
      std::swap(client_graph_before_register_, client_graph);
      mu_.unlock();
      std::unordered_map<string, GraphDef> graph_defs;
      popts.flib_def = client_graph->flib_def.get();
      
      // Build the partitions
      Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
      if (s.ok()) {
        // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
        // valid after the call to DoRegisterPartitions begins, so
        // `stats_publisher_` must make a copy if it wants to retain the
        // GraphDef objects.
        std::vector<const GraphDef*> graph_defs_for_publishing;
        graph_defs_for_publishing.reserve(partitions_.size());
        for (const auto& name_def : graph_defs) {
          graph_defs_for_publishing.push_back(&name_def.second);
        }
        
        stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
        
        // Register the partitions
        s = DoRegisterPartitions(popts, std::move(graph_defs));
      }
      mu_.lock();
      init_result_ = s;
      init_done_.Notify();
    } else {
      mu_.unlock();
      init_done_.WaitForNotification();
      mu_.lock();
    }
    const Status result = init_result_;
    mu_.unlock();
    return result;
  }
}
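
RegisterPartitions lets exactly one caller build and register the partitions: that caller swaps client_graph_before_register_ out under the lock, does the expensive work without holding it, and publishes the result through init_done_; every other caller waits for the notification and returns the same status. A sketch of that pattern with std::mutex and std::condition_variable in place of TensorFlow's mutex and Notification (`RegisterOnce` is a hypothetical name):

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>

class RegisterOnce {
 public:
  std::string Run(const std::function<std::string()>& work) {
    std::unique_lock<std::mutex> l(mu_);
    if (!claimed_) {
      claimed_ = true;  // like swapping out client_graph_before_register_
      l.unlock();
      std::string result = work();  // expensive work, done without the lock
      l.lock();
      result_ = result;
      done_ = true;
      cv_.notify_all();  // like init_done_.Notify()
    } else {
      cv_.wait(l, [this] { return done_; });  // like WaitForNotification()
    }
    return result_;
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  bool claimed_ = false;
  bool done_ = false;
  std::string result_;
};
```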

2.3.3 Partitioning

DoBuildPartitions calls Partition to do the actual partitioning.

#include "tensorflow/core/graph/graph_partition.h"

Status MasterSession::ReffedClientGraph::DoBuildPartitions(
    PartitionOptions popts, ClientGraph* client_graph,
    std::unordered_map<string, GraphDef>* out_partitions) {
  if (popts.need_to_record_start_times) {
    CostModel cost_model(true);
    cost_model.InitFromGraph(client_graph->graph);
    // TODO(yuanbyu): Use the real cost model.
    // execution_state_->MergeFromGlobal(&cost_model);
    SlackAnalysis sa(&client_graph->graph, &cost_model);
    sa.ComputeAsap(&popts.start_times);
  }

  // Partition the graph.
  return Partition(popts, &client_graph->graph, out_partitions);
}
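
When recv scheduling is enabled, DoBuildPartitions fills popts.start_times via SlackAnalysis::ComputeAsap. As a rough illustration of what an "as soon as possible" start time means, here is a generic longest-path computation over a DAG whose nodes are already in topological order. It is a sketch under that assumption, not the SlackAnalysis implementation, and the per-node costs are made up for the example.

```cpp
#include <algorithm>
#include <cassert>
#include <vector>

// start[n] = max over predecessors p of (start[p] + cost[p]); nodes with no
// inputs start at 0. inputs[n] lists n's predecessors, all with index < n.
std::vector<long long> ComputeAsapSketch(
    const std::vector<std::vector<int>>& inputs,
    const std::vector<long long>& cost) {
  std::vector<long long> start(inputs.size(), 0);
  for (size_t n = 0; n < inputs.size(); ++n) {
    for (int p : inputs[n]) {
      assert(static_cast<size_t>(p) < n);  // topological order required
      start[n] = std::max(start[n], start[p] + cost[p]);
    }
  }
  return start;
}
```

For nodes 0 and 1 feeding node 2, which feeds node 3, with costs {5, 2, 3, 1}, node 2 can start at 5 and node 3 at 8.
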

2.3.3.1 Partition

The main logic of Partition is:

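  • Split the original computation graph into multiple subgraphs.
  • If nodes on different devices depend on each other, insert Send/Recv node pairs.
  • If needed, insert Control Flow edges.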

In more detail:

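  • Analyze the original graph and add the missing control-flow edges.
    • Add "code" for the distributed execution of control flow. Code is added only for frames placed on multiple devices. The new graph is an equivalent transformation of the original and has the property that it can later be partitioned arbitrarily (down to the level of individual devices) for distributed execution.
  • Build memory/device information for the nodes and edges of every operator, also in preparation for partitioning.
    • TF wants tensors used for computation to live on the device and tensors used for control to live on the Host. So each op must be analyzed to determine whether its CPU or GPU version is used, and the memory placement of its input and output tensors must be determined: for example, some ops run on the GPU but still read data from the CPU, and some data must be forced onto the CPU because it is not GPU-friendly.
  • Walk the graph's nodes to analyze and partition them, inserting Send/Recv nodes and control edges, which finally yields multiple subgraphs.
    • Take a node dst from the original graph, get its location, use the location to find that partition's GraphDef in partitions, add the node, and set its device.
    • Collect dst's input edges in the original graph, together with its control edges, into the inputs array.
    • Take an input edge of dst, get its src node, and get the src node's graph.
      • If src and dst are in the same graph, they share a partition and have compatible memory types, so connect src and dst inside that graph and move on to dst's next edge.
      • If src and dst are not in the same graph, communication is needed: build a communication key from the edge, src, and so on, and look up a Recv node in the cache by that key. If one is found, connect that Recv node to dst and move on to dst's next edge.
      • A control edge that crosses devices must have its dependency expressed equivalently across devices. Although a control edge carries no real tensor, a message must still be sent to the receiver so that it learns about the dependency. So a dummy const node is inserted on the src device, an identity node is inserted on the receiver to read this dummy const of shape 0, and the identity is made a control dependency of the receiver.
      • Add the Send node and the Recv node.
      • Fix up control/data relationships further.
        • For send/recv nodes on the same device, there is a data copy between them, so a control edge from send to recv is added. This prevents the asynchronous recv kernel from being scheduled before the data is available, which guarantees the execution order.
        • Otherwise the pair crosses devices, and the control edge is redirected to the real recv node according to the data flow.
  • Finish up: fill in each subgraph's version information and function library, and the Incarnation of send/recv nodes.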
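
The data-edge part of the steps above can be reduced to a few lines: group nodes by location, split each cross-partition edge with a Send/Recv pair, and use a (src, src_output, destination partition) key so that one tensor is sent to a given partition only once. The following sketch models just that; `Edge`, `Partitioned` and `PartitionSketch` are hypothetical, and control edges, memory types, casts and ref edges are deliberately ignored.

```cpp
#include <cassert>
#include <map>
#include <set>
#include <string>
#include <tuple>
#include <vector>

// Hypothetical data edge: src's output slot src_output feeds dst.
struct Edge {
  std::string src;
  int src_output;
  std::string dst;
};

struct Partitioned {
  // location -> node names placed there (original nodes plus Send/Recv)
  std::map<std::string, std::vector<std::string>> nodes;
  int num_send_recv_pairs = 0;
};

Partitioned PartitionSketch(const std::map<std::string, std::string>& loc,
                            const std::vector<Edge>& edges) {
  Partitioned out;
  for (const auto& [node, where] : loc) out.nodes[where].push_back(node);
  // Plays the role of DupRecvTable: remembers tensors already shipped.
  std::set<std::tuple<std::string, int, std::string>> dup_recv;
  for (const Edge& e : edges) {
    const std::string& src_loc = loc.at(e.src);
    const std::string& dst_loc = loc.at(e.dst);
    if (src_loc == dst_loc) continue;  // same partition: keep a plain edge
    if (!dup_recv.insert(std::make_tuple(e.src, e.src_output, dst_loc)).second) {
      continue;  // reuse the Recv created for this tensor earlier
    }
    const std::string tag =
        e.src + ":" + std::to_string(e.src_output) + "->" + dst_loc;
    out.nodes[src_loc].push_back("_Send(" + tag + ")");
    out.nodes[dst_loc].push_back("_Recv(" + tag + ")");
    ++out.num_send_recv_pairs;
  }
  return out;
}
```

In the test graph, `a:0` feeds both `b` and `c` on task 1, yet only one Send/Recv pair is created.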

For example, after partitioning:

Figure 2: Partitioned computation graph (source: TensorFlow)

After inserting Send/Recv nodes:

Figure 3: Inserted nodes (source: TensorFlow)

The Partition code is shown below, heavily simplified.

Status Partition(const PartitionOptions& opts, Graph* g,
                 std::unordered_map<string, GraphDef>* partitions) {
  Status status;
  partitions->clear();

  GraphInfo g_info;
  if (!opts.control_flow_added) {
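    // Analyze the original graph and add the missing control-flow edges.
    // Add "code" for distributed execution of control flow, only for frames
    // placed on multiple devices. The new graph is an equivalent
    // transformation of the original that can later be partitioned
    // arbitrarily (down to single devices) for distributed execution.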
    status = AddControlFlow(opts, g, &g_info);
    if (!status.ok()) return status;
  }

  // At this point, all the graph mutations have been done. Build memory
  // and device type info for every node and edge in the graph.
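  // Build memory/device info for every operator's nodes and edges, in
  // preparation for partitioning. TF wants compute tensors on the device and
  // control tensors on the host, so each op's CPU/GPU version and the memory
  // type of its inputs/outputs must be determined (e.g. some GPU ops still
  // read from CPU memory, and some data must stay on the CPU because it is
  // not GPU-friendly).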
  status = BuildMemoryDeviceInfo(*g, &g_info);
  if (!status.ok()) return status;

  string dstp;
  std::vector<const Edge*> inputs;
  DupRecvTable dup_recv(3);
  // For a node dst, 'ref_recvs' are the recvs introduced by ref edges into
  // dst, and 'ref_control_inputs' are the non-ref inputs to dst. For each
  // pair in (ref_recvs x ref_control_inputs), we add a control edge.
  std::vector<NodeDef*> ref_recvs;
  std::vector<string> ref_control_inputs;

  int32_t num_data = 0;
  int32_t num_control = 0;
  // Walk the graph's nodes: analyze and partition them, inserting
  // Send/Recv nodes and control edges.
  for (const Node* dst : g->op_nodes()) {
    // Take a node dst from the original graph.
    dstp = opts.node_to_loc(dst);  // dst's location
    GraphDef* dst_graph = &(*partitions)[dstp];  // that location's GraphDef
    NodeDef* dst_def = dst_graph->add_node();  // add the node
    *dst_def = dst->def();
    dst_def->set_device(dst->assigned_device_name());  // set the device
    dst_def->clear_input();  // Inputs are filled below

    // Arrange the incoming edges to dst so that input[i] holds the
    // input flowing into slot numbered i. Trailing entries in input[]
    // hold control edges.
    // Collect dst's input edges in the original graph, together with its
    // control edges, into the inputs array.
    inputs.clear();
    inputs.resize(dst->num_inputs(), nullptr);
    ref_recvs.clear();
    ref_control_inputs.clear();
    const Edge* control_flow_edge = nullptr;
    int32_t num_control_flow_edges = 0;
    int32_t num_input_edges = 0;
    for (const Edge* edge : dst->in_edges()) {
      if (edge->IsControlEdge()) {
        if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
          // This is one of the control edges added for control flow. There
          // can be multiple such edges as the dest node may have multiple
          // remote inputs. We keep track of the number of such edges.
          control_flow_edge = edge;
          ++num_control_flow_edges;
        } else {
          inputs.push_back(edge);
        }
      } else {
        DCHECK(inputs[edge->dst_input()] == nullptr);
        inputs[edge->dst_input()] = edge;
        ++num_input_edges;
      }
    }

    // Process in order so that all data edges are added as inputs to
    // dst in Edge::dst_input() order.
    for (const Edge* edge : inputs) {  // take one input edge of dst
      const Node* src = edge->src();  // the edge's src node
      if (!src->IsOp()) continue;  // Skip Sink/Source nodes.

      // Locate src with the configured node_to_loc (SplitByWorker here, or
      // SplitByDevice) to get src's graph.
      GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];
      if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
        // src and dst are in the same graph, so connect them directly here.
        // Same partition and compatible memory types:
        AddInput(dst_def, src->name(), edge->src_output());
        if (edge->IsControlEdge() ||
            !IsRefType(src->output_type(edge->src_output()))) {
          ref_control_inputs.push_back(src->name());
        }
        continue;  // move on to dst's next edge
      }

      // Check whether there is already a send/recv pair transferring
      // the same tensor/control from the src to dst partition.
      const bool on_host = IsDstInputOnHost(edge, g_info);
      // Not in the same graph, so communication is needed: build a key from
      // the edge, src and destination partition.
      DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
      auto iter = dup_recv.find(key);  // look up a cached Recv by key
      if (iter != dup_recv.end()) {  // found: connect that Recv to dst
        // We found one. Reuse the data/control transferred already.
        const string& recv_node_name = iter->second.recv->name();
        if (edge->IsControlEdge()) {
          AddInput(dst_def, recv_node_name, Graph::kControlSlot);
        } else {
          AddInput(dst_def, recv_node_name, 0);
        }
        ref_control_inputs.push_back(recv_node_name);
        continue;  // move on to dst's next edge
      }

      // Add the Send and Recv nodes.
      NodeDefBuilder::NodeOut send_from;  // what the Send node sends
      if (edge->IsControlEdge()) {
        // Insert a dummy const node that will generate a tiny
        // data element to be sent from send to recv.
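        // A cross-device control edge carries no real tensor, but the
        // receiver still needs a message to learn about the dependency:
        // insert a dummy const on the src device, an identity on the
        // receiver to read this shape-0 dummy const, and make the identity a
        // control dependency of the receiver.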
        NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
        if (!status.ok()) return status;
        AddInput(dummy, src->name(), Graph::kControlSlot);
        send_from.Reset(dummy->name(), 0, DT_FLOAT);
      } else {
        send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
      }

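      // Need to split edge by placing matching send/recv nodes on
      // the src/dst sides of the edge.
      // The start times used for recv scheduling (presumably derived from
      // opts.start_times when scheduling_for_recvs is set) are elided in
      // this simplified excerpt.
      int64_t send_start_time = 0;
      int64_t recv_start_time = 0;
      NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                              send_start_time, &status);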
      if (!status.ok()) return status;

      NodeDef* real_recv = nullptr;
      NodeDef* recv =
          AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
      if (!status.ok()) return status;

      if (src_graph == dst_graph) {
        // For same device send/recv, add a control edge from send to recv.
        // This prevents the asynchronous recv kernel from being scheduled
        // before the data is available.
        // (A same-device send/recv involves a data copy, so this control
        // edge also fixes the execution order.)
        AddInput(real_recv, send->name(), Graph::kControlSlot);
      } else if (control_flow_edge != nullptr) {
        // Redirect control edge to the real recv since this is not the same
        // device send/recv.
        // Otherwise this is cross-device: redirect the control edge to the
        // real recv according to the data flow.
        --num_control_flow_edges;
        AddInput(real_recv, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }

      if (!edge->IsControlEdge() &&
          IsRefType(src->output_type(edge->src_output()))) {
        // If src is of ref type and the edge is not a control edge, dst has
        // read semantics and therefore we must control the recv.
        ref_recvs.push_back(real_recv);
      } else {
        // Memorize the send/recv pair, only if this is not a "ref" edge.
        // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
        // for now we don't do it.
        dup_recv[key] = {recv, real_recv, recv_start_time};
        ref_control_inputs.push_back(recv->name());
      }

      if (edge->IsControlEdge()) {
        ++num_control;
        AddInput(dst_def, recv->name(), Graph::kControlSlot);
      } else {
        ++num_data;
        AddInput(dst_def, recv->name(), 0);
      }
    }

    // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
    // NOTE(yuanbyu): Adding these control edges should not introduce
    // deadlocks. 'dst' has implicit "read" nodes that, when we split
    // across devices, are made explicit; Retargeting the dependencies
    // to 'dst' to those nodes would not introduce cycles if there isn't
    // one before the transformation.
    // NOTE(yuanbyu): This may impact performance because it defers the
    // execution of recvs until all the other inputs become available.
    AddReadControl(ref_recvs, ref_control_inputs);

    // Add back the control edges for control flow that are not used.
    if (control_flow_edge != nullptr) {
      for (int i = 0; i < num_control_flow_edges; ++i) {
        AddInput(dst_def, control_flow_edge->src()->name(),
                 Graph::kControlSlot);
      }
    }
  }

  // Finish up: fill in each subgraph's versions, function library and
  // send/recv incarnations.
  const FunctionLibraryDefinition* flib_def = opts.flib_def;
  if (flib_def == nullptr) {
    flib_def = &g->flib_def();
  }

  // Set versions, function library and send/recv incarnation.
  for (auto& it : *partitions) {
    GraphDef* gdef = &it.second;
    *gdef->mutable_versions() = g->versions();
    // Prune unreachable functions from `flib_def` before adding them to `gdef`.
    *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();

    // Traverse the graph to fill every send/recv op's incarnation
    // information.
    SetIncarnation(opts, gdef);
  }

  return Status::OK();
}

Some of the functions used by Partition are described below.

2.3.3.2 AddDummyConst

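A control edge that crosses devices must have its dependency represented equivalently across devices. Although a control edge transfers no tensor, a message must still be sent to the receiver so that it learns about the dependency.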

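So a dummy const node is inserted on the src device to express the control dependency of the downstream node, an identity node is inserted on the receiver to read this dummy const of shape 0, and the identity is made a control dependency of the receiver. The dummy const node thus acts as the producer and the Identity as the consumer, which satisfies the cross-device communication requirement.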

NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                       const Edge* edge, Status* status) {
  const Node* src = edge->src();
  Tensor tensor(DT_FLOAT, TensorShape({0}));
  NodeDef* result = gdef->add_node();
  *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                .Device(src->assigned_device_name())
                .Attr("dtype", DT_FLOAT)
                .Attr("value", tensor)
                .Finalize(result, /*consume=*/true);
  return result;
}
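
Putting the pieces of the control-edge case together, a cross-partition control edge `src -> ^dst` becomes a chain of four nodes. The sketch below builds that chain with a hypothetical simplified node type, using the GraphDef convention that a '^' prefix marks a control input. The node names are illustrative, and the attributes that pair the Send with its Recv are not modeled.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical simplified NodeDef: name, op and inputs ("^x" = control).
struct SimpleNodeDef {
  std::string name;
  std::string op;
  std::vector<std::string> inputs;
};

void RewriteControlEdge(const std::string& src,
                        std::vector<SimpleNodeDef>* src_part,
                        std::vector<SimpleNodeDef>* dst_part,
                        SimpleNodeDef* dst) {
  const std::string dummy = src + "/dummy_const";
  const std::string send = src + "/send";
  const std::string recv = src + "/recv";
  const std::string identity = src + "/identity";
  // Producer side: a zero-element float Const that runs after src.
  src_part->push_back({dummy, "Const", {"^" + src}});
  src_part->push_back({send, "_Send", {dummy}});
  // Consumer side: Recv, then an Identity that reads the empty tensor.
  dst_part->push_back({recv, "_Recv", {}});
  dst_part->push_back({identity, "Identity", {recv}});
  // dst waits for the Identity through a control input.
  dst->inputs.push_back("^" + identity);
}
```
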

2.3.3.3 AddSend

If src and dst belong to two different partitions, the ordinary edge between them must be cut and Send and Recv nodes inserted in between, so that each side falls within its own partition.

NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge,
                 NodeDefBuilder::NodeOut send_from, int64_t start_time,
                 Status* status) {
  const DataType dtype = send_from.data_type;
  const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
  const Node* src = edge->src();
  const int src_port = edge->src_output();

  // host_memory = true iff we need to use HostSend/HostCast.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto src_it = g_info.output_types.find({src->id(), src_port});
    host_memory = (src_it->second == HOST_MEMORY);
  }

  // Add a cast node that casts dtype to cast_dtype.
  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Device(src->assigned_device_name()).Input(send_from);
    cast_builder.Attr("DstT", cast_dtype);

    if (cast_dtype == DT_BFLOAT16) {
      // the below attribute specifies that the cast to bfloat16 should use
      // truncation. This is needed to retain legacy behavior when we change
      // the default bfloat16 casts to use rounding instead of truncation
      cast_builder.Attr("Truncate", true);
    }

    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;

    // Connect the Send op to the cast.
    send_from.Reset(cast->name(), 0, cast_dtype);
  }

  // Add the send node.
  const string send_op = (host_memory) ? "_HostSend" : "_Send";
  NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &send_builder);
  send_builder.Device(src->assigned_device_name()).Input(send_from);

  NodeDef* send = gdef->add_node();
  *status = send_builder.Finalize(send, /*consume=*/true);
  return send;
}
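Stripped of the NodeDefBuilder plumbing, AddSend answers two questions: whether a Cast must precede the Send, and whether the host-memory variants are used. The sketch below models only that decision. DType, MemType and SenderSideOps are hypothetical. The same_device flag simply stands in for NeedSameDeviceSendRecv, whose logic is not reproduced here.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for TensorFlow's DataType / MemoryType enums.
enum class DType { kFloat, kBFloat16 };
enum class MemType { kDevice, kHost };

// Returns the ops AddSend would place on the source device, in order.
//   dtype       : type of the source output
//   cast_dtype  : result of opts.should_cast (equal to dtype when unset)
//   mem         : memory type of the source output
//   same_device : stand-in for NeedSameDeviceSendRecv(edge, g_info)
std::vector<std::string> SenderSideOps(DType dtype, DType cast_dtype, MemType mem,
                                       bool is_control_edge, bool same_device) {
  // Control edges never use host memory, matching the IsControlEdge() check.
  const bool host_memory = !is_control_edge && mem == MemType::kHost;
  std::vector<std::string> ops;
  // A Cast is added only for cross-device transfers that change the dtype.
  if (dtype != cast_dtype && !same_device) {
    ops.push_back(host_memory ? "_HostCast" : "Cast");
  }
  // The Send is always added; it reads from the Cast if one was inserted.
  ops.push_back(host_memory ? "_HostSend" : "_Send");
  return ops;
}
```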
2.2.3.4 AddRecv

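AddRecv implements the receiving side described earlier. For a control edge, it adds an Identity node that reads the {0}-shaped dummy Const arriving through the Recv. The caller (Partition) then makes this Identity a control input of the receiver. For a data edge whose type was cast for transfer, it instead adds a Cast back to the original dtype. Identity is the identity transform: it outputs its input tensor directly. This strips the reference marker from a variable and avoids a memory copy.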

NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                 GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                 Status* status) {
  const DataType dtype = EdgeType(edge);
  const Node* src = edge->src();
  const Node* dst = edge->dst();
  const int dst_port = edge->dst_input();
  DataType cast_dtype = dtype;

  // NOTE(yuanbyu): Only cast for cross-device send/recv.
  if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
    cast_dtype = opts.should_cast(edge);
  }

  // host_memory = true iff we need to use HostRecv/HostCast.
  // Also log the introduction of the send-recv pair, for performance debugging.
  bool host_memory = false;
  if (!edge->IsControlEdge()) {
    auto dst_it = g_info.input_types.find({dst->id(), dst_port});
    DCHECK(dst_it != g_info.input_types.end());
    host_memory = (dst_it->second == HOST_MEMORY);
    bool src_host_memory = false;  // used only by the logging omitted here
  } else {
    // Log control-edge transfers too, but don't mention memory space since it's
    // irrelevant.
    // (logging omitted)
  }

  // Add the recv node.
  const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
  NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                              NodeDebugInfo(*src));
  SetSendRecvAttrs(opts, edge, &recv_builder);
  recv_builder.Device(dst->assigned_device_name())
      .Attr("tensor_type", cast_dtype);
  NodeDef* recv = gdef->add_node();
  *status = recv_builder.Finalize(recv, /*consume=*/true);
  if (!status->ok()) return nullptr;
  *real_recv = recv;

  // Add the cast node (from cast_dtype to dtype) or an Identity node.
  if (dtype != cast_dtype) {
    const string cast_op = (host_memory) ? "_HostCast" : "Cast";
    NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                NodeDebugInfo(*src));
    cast_builder.Attr("DstT", dtype);
    cast_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* cast = gdef->add_node();
    *status = cast_builder.Finalize(cast, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return cast;
  } else if (edge->IsControlEdge()) {
    // An Identity is only needed for control edges.
    // The "Identity" node is added here.
    NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                              NodeDebugInfo(*src));
    id_builder.Device(dst->assigned_device_name())
        .Input(recv->name(), 0, cast_dtype);
    NodeDef* id = gdef->add_node();
    *status = id_builder.Finalize(id, /*consume=*/true);
    if (!status->ok()) return nullptr;
    return id;
  } else {
    return recv;
  }
}
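The receiving side mirrors AddSend. The sketch below, again with hypothetical DType, MemType and ReceiverSideOps, lists the ops AddRecv would add on the destination device. It assumes cast_dtype has already been resolved the way AddRecv resolves it, that is, left equal to dtype for same-device send/recv.

```cpp
#include <cassert>
#include <string>
#include <vector>

// Hypothetical stand-ins for TensorFlow's DataType / MemoryType enums.
enum class DType { kFloat, kBFloat16 };
enum class MemType { kDevice, kHost };

// Returns the ops AddRecv would place on the destination device, in order.
// The last op is the node the destination actually depends on.
std::vector<std::string> ReceiverSideOps(DType dtype, DType cast_dtype, MemType mem,
                                         bool is_control_edge) {
  // As in AddRecv, host memory is only considered for data edges.
  const bool host_memory = !is_control_edge && mem == MemType::kHost;
  std::vector<std::string> ops;
  ops.push_back(host_memory ? "_HostRecv" : "_Recv");
  if (dtype != cast_dtype) {
    // Cast back from the transfer type to the type dst expects.
    ops.push_back(host_memory ? "_HostCast" : "Cast");
  } else if (is_control_edge) {
    // Control edges get an Identity that dst will take as a control input.
    ops.push_back("Identity");
  }
  // Otherwise dst reads the Recv output directly.
  return ops;
}
```

Only the last op matters to the caller. It is the node whose output, or for a control edge whose completion, dst depends on.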
2.2.3.5 AddInput

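AddInput adds an input to the downstream (dst) node. A control input is written as "^src_name", output slot 0 as plain "src_name", and any other slot as "src_name:slot".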

// Add an input to dst that comes from the "src_slot" output of the
// node named by "src_name".
void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
  if (src_slot == Graph::kControlSlot) {
    dst->add_input(strings::StrCat("^", src_name));
  } else if (src_slot == 0) {
    dst->add_input(src_name.data(), src_name.size());
  } else {
    dst->add_input(strings::StrCat(src_name, ":", src_slot));
  }
}
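The input-string convention is easy to check in isolation. This sketch uses std::string in place of StringPiece and strings::StrCat. kControlSlot is assumed to be -1, the value of Graph::kControlSlot; FormatInput is a hypothetical name.

```cpp
#include <cassert>
#include <string>

// Assumed to equal Graph::kControlSlot (-1) in TensorFlow.
constexpr int kControlSlot = -1;

// Formats a NodeDef input string the same way AddInput does.
std::string FormatInput(const std::string& src_name, int src_slot) {
  assert(src_slot >= kControlSlot);
  if (src_slot == kControlSlot) return "^" + src_name;  // control dependency
  if (src_slot == 0) return src_name;                   // slot 0 is implicit
  return src_name + ":" + std::to_string(src_slot);     // explicit output slot
}
```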
2.2.3.6 AddReadControl

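AddReadControl also implements the control dependency through add_input. It adds every name in inputs as a control input ("^input") of every recv node.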

// Add a control edge from each input to each recv.
void AddReadControl(const std::vector<NodeDef*>& recvs,
                    const std::vector<string>& inputs) {
  for (NodeDef* recv : recvs) {
    for (const string& input : inputs) {
      recv->add_input(strings::StrCat("^", input));
    }
  }
}

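2.2.4 Registration

With partitioning complete, we now reach the registration stage.

2.2.4.1 DoRegisterPartitions

DoRegisterPartitions decides which worker is responsible for which partition. The key code is:

  • part->worker = worker_cache_->GetOrCreateWorker(part->name), which sets the worker of each part.

  • part.worker->RegisterGraphAsync(&c->req, &c->resp, cb), which registers the graph.

  • done.Wait(), which blocks on a BlockingCounter until every RegisterGraphAsync callback has run. Each partition's returned graph_handle is then recorded.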

Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
    const PartitionOptions& popts,
    std::unordered_map<string, GraphDef> graph_partitions) {
  partitions_.reserve(graph_partitions.size());
  Status s;
  for (auto& name_def : graph_partitions) {
    partitions_.emplace_back();
    Part* part = &partitions_.back();
    part->name = name_def.first;
    TrackFeedsAndFetches(part, name_def.second, popts);
    part->worker = worker_cache_->GetOrCreateWorker(part->name);
    if (part->worker == nullptr) {
      s = errors::NotFound("worker ", part->name);
      break;
    }
  }
  if (!s.ok()) {
    for (Part& part : partitions_) {
      worker_cache_->ReleaseWorker(part.name, part.worker);
      part.worker = nullptr;
    }
    return s;
  }
  struct Call {
    RegisterGraphRequest req;
    RegisterGraphResponse resp;
    Status status;
  };
  const int num = partitions_.size();
  gtl::InlinedVector<Call, 4> calls(num);
  BlockingCounter done(num);
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    Call* c = &calls[i];
    c->req.set_session_handle(session_handle_);
    c->req.set_create_worker_session_called(!should_deregister_);
    c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
    StripDefaultAttributes(*OpRegistry::Global(),
                           c->req.mutable_graph_def()->mutable_node());
    *c->req.mutable_config_proto() = session_opts_.config;
    *c->req.mutable_graph_options() = session_opts_.config.graph_options();
    *c->req.mutable_debug_options() =
        callable_opts_.run_options().debug_options();
    c->req.set_collective_graph_key(collective_graph_key_);

    auto cb = [c, &done](const Status& s) {
      c->status = s;
      done.DecrementCount();
    };
    part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
  }
  done.Wait();
  for (int i = 0; i < num; ++i) {
    Call* c = &calls[i];
    s.Update(c->status);
    partitions_[i].graph_handle = c->resp.graph_handle();
  }
  return s;
}
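The registration code follows a common fan-out/fan-in pattern. It starts one asynchronous call per partition, and each callback stores its result and decrements a counter. The caller then blocks until the counter reaches zero. The sketch below reproduces that pattern with standard C++ threads. ToyBlockingCounter, FakeRegisterGraphAsync and RegisterAll are hypothetical stand-ins for BlockingCounter, RegisterGraphAsync and DoRegisterPartitions, and no real RPC is made.

```cpp
#include <cassert>
#include <condition_variable>
#include <functional>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
#include <vector>

// Minimal stand-in for tensorflow::BlockingCounter.
class ToyBlockingCounter {
 public:
  explicit ToyBlockingCounter(int n) : count_(n) {}
  void DecrementCount() {
    std::lock_guard<std::mutex> lock(mu_);
    if (--count_ == 0) cv_.notify_all();
  }
  void Wait() {
    std::unique_lock<std::mutex> lock(mu_);
    cv_.wait(lock, [this] { return count_ == 0; });
  }

 private:
  std::mutex mu_;
  std::condition_variable cv_;
  int count_;
};

using RegisterDone = std::function<void(bool ok, std::string graph_handle)>;

// Hypothetical async RPC: answers on another thread. An empty worker name
// simulates a failed registration.
std::thread FakeRegisterGraphAsync(const std::string& worker, RegisterDone done) {
  return std::thread([worker, done] { done(!worker.empty(), worker + "/graph"); });
}

// One call per partition, wait for all callbacks, then fold the results.
std::vector<std::string> RegisterAll(const std::vector<std::string>& workers,
                                     bool* all_ok) {
  struct Call {
    bool ok = false;
    std::string handle;
  };
  std::vector<Call> calls(workers.size());
  ToyBlockingCounter done(static_cast<int>(workers.size()));
  std::vector<std::thread> threads;
  for (size_t i = 0; i < workers.size(); ++i) {
    Call* c = &calls[i];  // each callback writes only to its own slot
    threads.push_back(FakeRegisterGraphAsync(
        workers[i], [c, &done](bool ok, std::string handle) {
          c->ok = ok;
          c->handle = std::move(handle);
          done.DecrementCount();
        }));
  }
  done.Wait();
  for (std::thread& t : threads) t.join();  // only to shut the toy threads down
  *all_ok = true;
  std::vector<std::string> handles;
  for (const Call& c : calls) {
    *all_ok = *all_ok && c.ok;  // like s.Update(c->status)
    handles.push_back(c.handle);
  }
  return handles;
}
```

As in the original, each callback writes only to its own slot (calls[i]), so no extra lock guards the results. The counter's mutex makes the writes visible once Wait() returns.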
2.2.4.2 GrpcRemoteWorker

The part.worker->RegisterGraphAsync call above ends up in GrpcRemoteWorker, which finally sends a RegisterGraphRequest to the downstream Worker.

In tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc, RegisterGraphAsync issues the RPC:

void RegisterGraphAsync(const RegisterGraphRequest* request,
                        RegisterGraphResponse* response,
                        StatusCallback done) override {
  IssueRequest(request, response, registergraph_, std::move(done));
}

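Note that the Master performs RegisterGraph only once, unless the graph's nodes are re-placed or the Master process is restarted. Conceptually, it looks like this:

Figure 4: Registering graphs (source: TensorFlow)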
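Registering once and running many times amounts to caching the graph handle returned by RegisterGraph. The sketch below shows that idea with a hypothetical ToyGraphRegistry keyed by partition name. The keying and the Reset() method, which stands in for re-placement or a Master restart, are assumptions for illustration. They are not how MasterSession actually stores its state.

```cpp
#include <cassert>
#include <map>
#include <string>

// Hypothetical registry: a partition is registered (one RegisterGraph RPC)
// only the first time it is needed; later steps reuse its graph handle.
class ToyGraphRegistry {
 public:
  const std::string& GetOrRegister(const std::string& partition) {
    auto it = handles_.find(partition);
    if (it == handles_.end()) {
      ++register_calls_;  // stands in for one RegisterGraph RPC
      const std::string handle = partition + "#" + std::to_string(register_calls_);
      it = handles_.emplace(partition, handle).first;
    }
    return it->second;
  }
  // Models re-placement or a Master restart: cached handles are dropped.
  void Reset() { handles_.clear(); }
  int register_calls() const { return register_calls_; }

 private:
  std::map<std::string, std::string> handles_;
  int register_calls_ = 0;
};
```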

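2.4 Running the Computation Graph

Partitioning is finished and the partitions are registered on the remote Workers, so each worker now holds its own subgraph. The next step is to run those subgraphs.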

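The Master triggers subgraph computation on a Worker by calling RunGraph. The Worker then runs the TensorFlow kernels on its GPU/CPU devices. Data moves between Workers and devices by different mechanisms, depending on the case:

  • Between a GPU and a CPU on the same machine: cudaMemcpyAsync.
  • Between two GPUs on the same machine: peer-to-peer DMA.
  • Between Workers: gRPC (over TCP) and RDMA (over Converged Ethernet).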
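The list above can be read as a simple decision rule. The sketch encodes it with a hypothetical ToyDevice type and PickTransport function. The rdma_enabled flag is an assumption standing in for however the cluster is configured. The CPU-to-CPU case, which the list does not mention, is labeled here as a plain in-process transfer.

```cpp
#include <cassert>
#include <string>

enum class DeviceKind { kCPU, kGPU };

// Hypothetical device descriptor: the machine it lives on and its kind.
struct ToyDevice {
  std::string host;
  DeviceKind kind;
};

// Maps a src -> dst pair onto the mechanisms listed above.
std::string PickTransport(const ToyDevice& src, const ToyDevice& dst,
                          bool rdma_enabled) {
  if (src.host != dst.host) {
    return rdma_enabled ? "RDMA" : "gRPC";  // between Workers
  }
  if (src.kind == DeviceKind::kGPU && dst.kind == DeviceKind::kGPU) {
    return "peer-to-peer DMA";  // GPU <-> GPU on one machine
  }
  if (src.kind != dst.kind) {
    return "cudaMemcpyAsync";  // CPU <-> GPU on one machine
  }
  return "in-process";  // CPU <-> CPU on one machine (not in the list above)
}
```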

Figure 5: Running subgraphs

2.4.1 RunPartitions

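RunPartitions calls RunPartitionsHelper to execute the subgraphs. Before the call, it maps each fed tensor name to its index in the request and rejects duplicate feeds. After the call, it copies the fetched tensors into the response in the order given by callable_opts_.fetch().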

Status MasterSession::ReffedClientGraph::RunPartitions(
    const MasterEnv* env, int64_t step_id, int64_t execution_count,
    PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
    RunCallableResponse* resp, CancellationManager* cm) {

  // Maps the names of fed tensors to their index in `req`.
  std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
  for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
    if (!feeds.insert({callable_opts_.feed(i), i}).second) {
      // MakeCallable will fail if there are two feeds with the same name.
      return errors::Internal("Duplicated feeds in callable: ",
                              callable_opts_.feed(i));
    }
  }

  // Create a wrapped response object to collect the fetched values and
  // rearrange them for the RunCallableResponse.
  RunCallableResponseWrapper wrapped_resp;
  wrapped_resp.resp = resp;

  // The actual execution happens here.
  TF_RETURN_IF_ERROR(RunPartitionsHelper(
      feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
      call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));

  // Collects fetches.
  for (const string& fetch : callable_opts_.fetch()) {
    TensorProto* fetch_proto = resp->mutable_fetch()->Add();
    auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
    if (iter == wrapped_resp.fetch_key_to_protos.end()) {
      return errors::Internal("Worker did not return a value for fetch: ",
                              fetch);
    }
    fetch_proto->Swap(&iter->second);
  }
  return Status::OK();
}
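RunPartitions itself does two bookkeeping jobs around the helper. It indexes the feeds by name, rejecting duplicates, and it returns the fetches in the order requested. Both are sketched below with strings standing in for tensors. IndexFeeds and CollectFetches are hypothetical names, and std::nullopt replaces the errors::Internal statuses of the original.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

// Maps each fed tensor name to its position in the request; a duplicated
// name yields std::nullopt.
std::optional<std::unordered_map<std::string, std::size_t>> IndexFeeds(
    const std::vector<std::string>& feeds) {
  std::unordered_map<std::string, std::size_t> index;
  for (std::size_t i = 0; i < feeds.size(); ++i) {
    if (!index.insert({feeds[i], i}).second) return std::nullopt;
  }
  return index;
}

// Emits fetched values in the caller's fetch order, moving them out of
// `fetched` (the original swaps TensorProtos); a missing fetch is an error.
std::optional<std::vector<std::string>> CollectFetches(
    const std::vector<std::string>& fetch_order,
    std::unordered_map<std::string, std::string>& fetched) {
  std::vector<std::string> out;
  for (const std::string& name : fetch_order) {
    auto it = fetched.find(name);
    if (it == fetched.end()) return std::nullopt;
    out.push_back(std::move(it->second));
  }
  return out;
}
```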

2.4.2 RunPartitionsHelper

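RunPartitionsHelper executes the subgraphs. Its logic is:

  • For each partition, set up a RunManyGraphs::Call and fill in its request, response, session handle, graph handle, request id and recv keys.
  • Send a RunGraphAsync to each worker.
    • Each subgraph is assigned to one worker and corresponds to one worker service.
    • part.worker is the partition's WorkerInterface object: a GrpcRemoteWorker instance if the worker is remote, otherwise a Worker instance.
  • Register the cancellation callbacks and wait for the RunGraphAsync results.
  • Process the results.
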
template <class FetchListType, class ClientRequestType,
          class ClientResponseType>
Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
    const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
    const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
    int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
    const ClientRequestType& req, ClientResponseType* resp,
    CancellationManager* cm, bool is_last_partial_run) {
  // Collect execution cost stats on a smoothly decreasing frequency.
  ExecutorOpts exec_opts;
  // Statistics code omitted.

  const int num = partitions_.size();
  RunManyGraphs calls(num);

  for (int i = 0; i < num; ++i) {
    // Set up a RunManyGraphs::Call for each partition
    const Part& part = partitions_[i];
    RunManyGraphs::Call* c = calls.get(i);
    c->worker_name = &part.name;
    c->req.reset(part.worker->CreateRunGraphRequest()); // set up the request
    c->resp.reset(part.worker->CreateRunGraphResponse()); // set up the response
    if (is_partial_) {
      c->req->set_is_partial(is_partial_);
      c->req->set_is_last_partial_run(is_last_partial_run);
    }
    c->req->set_session_handle(session_handle_); // set the session handle
    c->req->set_create_worker_session_called(!should_deregister_);
    c->req->set_graph_handle(part.graph_handle); // set the graph handle
    c->req->set_step_id(step_id);
    *c->req->mutable_exec_opts() = exec_opts;
    c->req->set_store_errors_in_response_body(true);
    c->req->set_request_id(GetUniqueRequestId()); // set the request id
    // If any feeds are provided, send the feed values together
    // in the RunGraph request.
    // In the partial case, we only want to include feeds provided in the req.
    // In the non-partial case, all feeds in the request are in the part.
    // We keep these as separate paths for now, to ensure we aren't
    // inadvertently slowing down the normal run path.
    if (is_partial_) {
      for (const auto& name_index : feeds) {
        const auto iter = part.feed_key.find(string(name_index.first));
        if (iter == part.feed_key.end()) {
          // The provided feed must be for a different partition.
          continue;
        }
        const string& key = iter->second;
        TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
                                                    name_index.second, key));
      }
      // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
      // For now, we just iterate through partitions to find the matching key.
      for (const string& req_fetch : fetches) {
        for (const auto& key_fetch : part.key_fetch) {
          if (key_fetch.second == req_fetch) {
            c->req->add_recv_key(key_fetch.first); // set the recv key
            break;
          }
        }
      }
    } else {
      for (const auto& feed_key : part.feed_key) {
        const string& feed = feed_key.first;
        const string& key = feed_key.second;
        auto iter = feeds.find(feed);
        if (iter == feeds.end()) {
          return errors::Internal("No feed index found for feed: ", feed);
        }
        const int64_t feed_index = iter->second;
        TF_RETURN_IF_ERROR(
            AddSendFromClientRequest(req, c->req.get(), feed_index, key));
      }
      for (const auto& key_fetch : part.key_fetch) {
        const string& key = key_fetch.first;
        c->req->add_recv_key(key); // set the recv key
      }
    }
  }

  // Issues RunGraph calls.
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    RunManyGraphs::Call* call = calls.get(i);
    part.worker->RunGraphAsync( // send RunGraphAsync to each worker
        &call->opts, call->req.get(), call->resp.get(),
        std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
  }

  // Waits for the RunGraph calls.
  // Register the cancellation callbacks, then wait for the results.
  call_opts->SetCancelCallback([&calls]() {
    calls.StartCancel();
  });
  auto token = cm->get_cancellation_token();
  const bool success =
      cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
  if (!success) {
    calls.StartCancel();
  }
  calls.Wait();
  call_opts->ClearCancelCallback();
  if (success) {
    cm->DeregisterCallback(token);
  } else {
    return errors::Cancelled("Step was cancelled");
  }

  // Collects fetches and metadata.
  // Process the results.
  Status status;
  for (int i = 0; i < num; ++i) {
    const Part& part = partitions_[i];
    MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
    for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
      auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
      if (iter == part.key_fetch.end()) {
        status.Update(errors::Internal("Unexpected fetch key: ",
                                       run_graph_resp->recv_key(j)));
        break;
      }
      const string& fetch = iter->second;
      status.Update(
          resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
      if (!status.ok()) {
        break;
      }
    }
    if (pss->collect_timeline) {
      pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
    }
    if (pss->collect_costs) {
      CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
      for (int j = 0; j < cost_graph->node_size(); ++j) {
        resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
            cost_graph->mutable_node(j));
      }
    }
    if (pss->collect_partition_graphs) {
      protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
          resp->mutable_metadata()->mutable_partition_graphs();
      for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
        partition_graph_defs->Add()->Swap(
            run_graph_resp->mutable_partition_graph(i));
      }
    }
  }
  return status;
}
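The feed handling in RunPartitionsHelper differs between partial and normal runs. The sketch below isolates that branch. ToyPart and FeedKeysForPart are hypothetical. The function returns only the rendezvous keys, whereas the real code also copies each feed value into the request via AddSendFromClientRequest.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <optional>
#include <string>
#include <unordered_map>
#include <vector>

// Hypothetical view of a Part: fed tensor name -> rendezvous key.
struct ToyPart {
  std::map<std::string, std::string> feed_key;
};

// Returns the rendezvous keys this partition's RunGraph request must carry,
// following the two branches of RunPartitionsHelper.
std::optional<std::vector<std::string>> FeedKeysForPart(
    const ToyPart& part,
    const std::unordered_map<std::string, std::size_t>& feeds,
    bool is_partial) {
  std::vector<std::string> keys;
  if (is_partial) {
    // Partial run: walk the feeds in this request and skip the ones that
    // belong to another partition.
    for (const auto& name_index : feeds) {
      auto it = part.feed_key.find(name_index.first);
      if (it == part.feed_key.end()) continue;
      keys.push_back(it->second);
    }
  } else {
    // Normal run: every feed recorded for this partition must be present.
    for (const auto& feed_and_key : part.feed_key) {
      if (feeds.count(feed_and_key.first) == 0) {
        return std::nullopt;  // "No feed index found for feed"
      }
      keys.push_back(feed_and_key.second);
    }
  }
  return keys;
}
```

This matches the comments in the original. A partial run includes only the feeds provided in the current request, so absent ones are skipped. In a normal run, every feed of the partition is expected in the request, so a missing one is an error.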

2.4.3 GrpcRemoteWorker

The code above calls the following to notify the remote Worker to run its subgraph:

part.worker->RunGraphAsync(
    &call->opts, call->req.get(), call->resp.get(),
    std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));

RunGraphAsync is defined in GrpcRemoteWorker. Each GrpcRemoteWorker method calls IssueRequest() to start an asynchronous gRPC call.

void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                   RunGraphResponse* response, StatusCallback done) override {
  IssueRequest(request, response, rungraph_, std::move(done), call_opts);
}

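On the remote side, GrpcWorkerService runs as a daemon and handles incoming gRPC requests.

The overall logic of DoRunWithLocalExecution is summarized below:

Figure 6: Overall logic of DoRunWithLocalExecution

2.5 Summary

The execution logic is summarized below. Note that there are two gRPC calls here, RegisterGraph and RunGraph. RegisterGraph is called first to register the subgraphs on the remote Workers. RunGraph is then called to have the remote Workers compute them.

Figure 7: Master dynamic logic (2)

Next, we will go to the Worker and see what happens there.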
