MXNet Source Code Analysis | KVStore Inter-Process Communication

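This article is based on MXNet 1.6.0.

In the previous article, we analyzed the intra-process communication mechanism of KVStore in MXNet. In this article, we focus on how KVStore performs distributed communication across multiple nodes.

In the KVStore implementation, KVStoreDist and KVStoreDistServer correspond to the worker nodes and the server nodes of the parameter server, respectively. KVStoreDist inherits from KVStoreLocal and implements interfaces such as Push and Pull by wrapping PS-Lite's KVWorker, through which it sends requests to the servers. KVStoreDistServer wraps PS-Lite's KVServer and uses it to handle and answer the requests sent by workers.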

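Worker-side execution logic

Creating a worker

The KVStoreDist constructor creates one ps::KVWorker<char> object for each worker node. If the current worker is not a recovery node, it then blocks at a barrier until all workers, all servers, and the scheduler have started.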

explicit KVStoreDist(bool use_device_comm)
    : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
  if (IsWorkerNode()) {
    int new_customer_id = GetNewCustomerId();
    ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
    ps::StartAsync(new_customer_id, "mxnet\0");
    if (!ps::Postoffice::Get()->is_recovery()) { 
      ps::Postoffice::Get()->Barrier(
        new_customer_id,
        ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
    }
  }
  bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
  log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
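
The group argument passed to Barrier is a sum of PS-Lite group IDs. The following is a minimal sketch of why a sum can select several roles at once. We assume the three IDs are distinct powers of two (the values 1, 2 and 4 are our assumption about ps-lite, not something shown in the code above); InGroup and BarrierGroup are hypothetical helpers introduced only for illustration.

```cpp
#include <cassert>

// Assumed values mirroring ps-lite's group IDs (not shown in the excerpt).
constexpr int kScheduler = 1;
constexpr int kServerGroup = 2;
constexpr int kWorkerGroup = 4;

// Hypothetical helper: does a combined group ID include the given role?
constexpr bool InGroup(int group_id, int role) {
  return (group_id & role) != 0;
}

// The group used by the barrier in the KVStoreDist constructor.
constexpr int BarrierGroup() {
  return kWorkerGroup + kServerGroup + kScheduler;
}

static_assert(InGroup(BarrierGroup(), kScheduler),
              "the constructor barrier also covers the scheduler");
```

Because the IDs do not share bits, adding them is equivalent to OR-ing them, so a single integer names every role that has to reach the barrier.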

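Worker initialization

During initialization, each worker first checks that the keys are unique and then calls comm_->Init to set up the local communication resources for every key. Once local initialization is done, worker 0 pushes its local weights to the servers. It then waits for that push to finish by calling WaitToWrite on comm_buf_[key] and compr_buf_[key], the two buffers involved in the push. Finally, unless the node is recovering, every worker waits at a barrier.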

void InitImpl(const std::vector<int>& keys,
              const std::vector<NDArray>& values) override {
  CheckUnique(keys);
  for (size_t i = 0; i < keys.size(); ++i) {
    comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
  }
  if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
    Push_(keys, values, 0, false);
    // wait until the push is finished
    for (const int key : keys) {
      comm_buf_[key].WaitToWrite();
      compr_buf_[key].WaitToWrite();
    }
  } else {
    // do nothing
  }
  if (!ps::Postoffice::Get()->is_recovery()) {
    Barrier();
  }
}
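
The effect of the rank check can be illustrated with a small simulation. This is only a sketch under our own assumptions: Store stands in for the server-side storage, InitOnWorker is a hypothetical stand-in for InitImpl, and the barrier and the network are left out entirely.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

// Stand-in for the server-side key -> weights storage (hypothetical).
using Store = std::map<int, std::vector<float>>;

// Hypothetical stand-in for InitImpl: every worker runs it, but only
// rank 0 actually seeds the server with initial values.
void InitOnWorker(int rank,
                  const std::vector<int>& keys,
                  const std::vector<std::vector<float>>& values,
                  Store* server_store) {
  assert(keys.size() == values.size());
  if (rank != 0) return;  // other ranks push nothing during init
  for (std::size_t i = 0; i < keys.size(); ++i) {
    // The first push of a key becomes the server's initial copy.
    server_store->emplace(keys[i], values[i]);
  }
}
```

Whatever initial values the other workers hold locally, the servers start from worker 0's copy, which is why every worker later pulls before training on the weights.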

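Sending control messages from the worker

The worker sends control messages to the servers through the SendCommandToServers function. For example, the KVStoreDist destructor contains the following code, which makes worker 0 send a stop command to every server node.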

if (get_rank() == 0 && ps_worker_->get_customer()->customer_id() == 0) {
  // stop the executor at servers
  SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
}

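Sending data messages from the worker

The worker calls Push_ to send data requests to the servers. Its core logic is shown below (some code is omitted). As with the local communication discussed earlier, before anything is sent to a server, GroupKVPairsPush collects the values that share a key into one vector. For each key, when do_merge is set, a local Reduce first aggregates the gradients from all devices, and the result is placed in comm_buf. EncodeDefaultKey then encodes the key and value into the data structure that PS-Lite expects, and PushDefault sends the data.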

void KVStoreDist::Push_(const std::vector<int>& keys,
                        const std::vector<NDArray>& values,
                        int priority,
                        bool do_merge) {
  std::vector<int> uniq_keys;
  std::vector<std::vector<NDArray>> grouped_vals;
  GroupKVPairsPush(keys, values, &uniq_keys, &grouped_vals, false);

  for (size_t i = 0; i < uniq_keys.size(); ++i) {
    int key = uniq_keys[i];
    const auto& vals = grouped_vals[i];
    NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];

    auto &comm_buf = comm_buf_[key];
    if (merged.ctx().dev_mask() == cpu::kDevMask) {
      // Start of a push doesn't guarantee that the previous pushes are completed.
      // This shouldn't affect training of networks though because training involves
      // a sequence of push, pull, then push. This imposes ordering that the
      // second push happens after the first pull, and the pull happens after first push.
      comm_buf = merged;  // avoid memory copy
    } else {
      if (comm_buf.is_none()) {
        comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
      }
      CopyFromTo(merged, &comm_buf);
    }
    const int dtype = merged.dtype();
    const int num_bytes = mshadow::mshadow_sizeof(dtype);
    PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), num_bytes);
    PushDefault(key, comm_buf, pskv, priority);
  }
}
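
The grouping and local reduction steps can be sketched in isolation. The code below is a simplified illustration, not MXNet's implementation: Tensor is a plain vector of floats, and GroupByKey and Reduce are hypothetical helpers that only imitate what GroupKVPairsPush and comm_->Reduce do conceptually.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

using Tensor = std::vector<float>;

// Collect the values that share a key, keeping keys in first-seen order.
void GroupByKey(const std::vector<int>& keys,
                const std::vector<Tensor>& values,
                std::vector<int>* uniq_keys,
                std::vector<std::vector<Tensor>>* grouped) {
  assert(keys.size() == values.size());
  std::map<int, std::size_t> slot;  // key -> index into *grouped
  for (std::size_t i = 0; i < keys.size(); ++i) {
    auto it = slot.find(keys[i]);
    if (it == slot.end()) {
      it = slot.emplace(keys[i], uniq_keys->size()).first;
      uniq_keys->push_back(keys[i]);
      grouped->emplace_back();
    }
    (*grouped)[it->second].push_back(values[i]);
  }
}

// Element-wise sum of the per-device gradients of one key.
Tensor Reduce(const std::vector<Tensor>& vals) {
  assert(!vals.empty());
  Tensor out = vals[0];
  for (std::size_t d = 1; d < vals.size(); ++d) {
    assert(vals[d].size() == out.size());
    for (std::size_t j = 0; j < out.size(); ++j) out[j] += vals[d][j];
  }
  return out;
}
```

After these two steps there is exactly one merged array per key, so each key causes a single network push regardless of how many devices contributed a gradient.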

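PushDefault calls ps_worker_->ZPush to send the gradient. The send itself, together with the preparation that precedes it, is wrapped in a lambda expression, which is then pushed onto MXNet's backend dependency engine to await execution.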

void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
  auto push_to_servers =
      [this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
        const int dtype = send_buf.dtype();
        // convert to ps keys
        const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
        char* data = static_cast<char *>(send_buf.data().dptr_);
        // do push. false means no delete
        ps::SArray<char> vals(data, size, false);
        int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
        CHECK_NOTNULL(ps_worker_)->ZPush(
            pskv.keys, vals, pskv.lens,
            cmd, [cb]() { cb(); });
      };
  Engine::Get()->PushAsync(
      push_to_servers,
      pinned_ctx_,
      {send_buf.var()},
      {},
      FnProperty::kNormal,
      priority,
      "KVStoreDistDefaultPush");
}
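
The pattern of queuing a lambda that reports completion through a callback can be shown with a toy engine. TinyEngine below is a hypothetical, single-threaded miniature written for this article; it ignores the read/write variable dependencies, contexts and priorities that the real engine uses for scheduling.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>

// The operation receives a callback and must invoke it when it is done.
using OnComplete = std::function<void()>;
using AsyncFn = std::function<void(OnComplete)>;

// Hypothetical miniature engine: operations are queued first, run later.
class TinyEngine {
 public:
  // Queue the operation; nothing is executed at this point.
  void PushAsync(AsyncFn fn) { pending_.push(std::move(fn)); }

  // Run queued operations in order. Each operation signals completion
  // by calling the callback it is handed.
  void RunAll() {
    while (!pending_.empty()) {
      AsyncFn fn = std::move(pending_.front());
      pending_.pop();
      fn([this]() { ++completed_; });
    }
  }

  int completed() const { return completed_; }

 private:
  std::queue<AsyncFn> pending_;
  int completed_ = 0;
};
```

In PushDefault the callback is forwarded into ZPush, so the engine only treats the operation as finished when PS-Lite reports that the push has completed, not when the lambda returns.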

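The Pull operation proceeds as shown below. Once the preparation is done, ps_worker_->ZPull pulls the weights from the servers, and a local Broadcast then copies the pulled weights to all devices.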

void PullImpl(const std::vector<int>& keys,
              const std::vector<NDArray*>& values,
              int priority, bool ignore_sparse) override {
  CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";
  std::vector<int> uniq_keys;
  std::vector<std::vector<NDArray*> > grouped_vals;
  GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true);

  for (size_t i = 0; i < uniq_keys.size(); ++i) {
    int key = uniq_keys[i];
    // use the same array for merging to guarantee that pull always happens
    // after the previous push on this key
    auto& recv_buf = comm_buf_[key];
    const auto storage_type = grouped_vals[i][0]->storage_type();
    CHECK_EQ(storage_type, kDefaultStorage)
             << "Expected stype of value to be kDefaultStorage";
    if (recv_buf.is_none()) {
      // it may happen for the first time a no-rank-0 worker pull the weight.
      recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
                         true, grouped_vals[i][0]->dtype());
    }
    auto pull_from_servers = [this, key, recv_buf](
        RunContext rctx, Engine::CallbackOnComplete cb) {
      // convert to ps keys
      size_t size = recv_buf.shape().Size();
      const int dtype = recv_buf.dtype();
      const int num_bytes = mshadow::mshadow_sizeof(dtype);
      PSKV& pskv = EncodeDefaultKey(key, size, num_bytes);
      char* data = static_cast<char*> (recv_buf.data().dptr_);
      // false means not to delete data when SArray is deleted
      auto vals = new ps::SArray<char>(data, size * num_bytes, false);
      // issue pull
      RequestType mode = RequestType::kDefaultPushPull;
      const int cmd = GetCommandType(mode, dtype);
      CHECK_NOTNULL(ps_worker_)->ZPull(
       pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
    };

    CHECK_NOTNULL(Engine::Get())->PushAsync(
        pull_from_servers,
        pinned_ctx_,
        {},
        {recv_buf.var()},
        FnProperty::kNormal,
        priority,
        "KVStoreDistDefaultStoragePull");

    comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
  }
}
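
The final broadcast step amounts to copying one receive buffer into every per-device output. A minimal sketch, assuming plain float vectors in place of NDArray and a hypothetical Broadcast helper that copies synchronously (the real call is asynchronous and device-aware):

```cpp
#include <cassert>
#include <vector>

using Tensor = std::vector<float>;

// Copy the pulled weights into every per-device output array.
void Broadcast(const Tensor& recv_buf, const std::vector<Tensor*>& outputs) {
  for (Tensor* out : outputs) {
    assert(out != nullptr);
    *out = recv_buf;  // each device ends up with an identical copy
  }
}
```

Pulling once into a shared buffer and fanning out locally means the network transfer happens once per key rather than once per device.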

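Server-side execution logic

Creating and starting the server

First, the KVStoreDistServer constructor binds two handlers to ps_server_: CommandHandle for command requests and DataHandleEx for data requests. Note that ps_server_ is upcast to ps::SimpleApp* when CommandHandle is bound. The reason is that the set_request_handle implemented in ps::SimpleApp only accepts a function object with two parameters, whereas ps::KVServer, which inherits from ps::SimpleApp, declares its own set_request_handle that accepts a function object with three parameters and thereby hides the base-class version. In this way, control requests and data requests are handled separately.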

KVStoreDistServer() {
  using namespace std::placeholders;
  ps_server_ = new ps::KVServer<char>(0);
  static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
      std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
  ps_server_->set_request_handle(
      std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
  sync_mode_ = false;
  gradient_compression_ = std::make_shared<GradientCompression>();
  log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
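
The need for the upcast follows from C++ name hiding, which can be reproduced without PS-Lite. In the sketch below, App and Server are hypothetical stand-ins for ps::SimpleApp and ps::KVServer; their handler signatures are simplified and are not the real ones.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <utility>

// Hypothetical stand-in for ps::SimpleApp: two-parameter command handler.
class App {
 public:
  using Handle = std::function<void(int head, const std::string& body)>;
  void set_request_handle(Handle h) { cmd_handle_ = std::move(h); }
  void DispatchCommand(int head, const std::string& body) {
    if (cmd_handle_) cmd_handle_(head, body);
  }

 private:
  Handle cmd_handle_;
};

// Hypothetical stand-in for ps::KVServer: three-parameter data handler.
class Server : public App {
 public:
  using DataHandle =
      std::function<void(int key, const std::string& data, Server* self)>;
  // Declaring this name hides App::set_request_handle inside Server.
  void set_request_handle(DataHandle h) { data_handle_ = std::move(h); }
  void DispatchData(int key, const std::string& data) {
    if (data_handle_) data_handle_(key, data, this);
  }

 private:
  DataHandle data_handle_;
};
```

With these definitions, calling set_request_handle on a Server only finds the three-parameter version, so the command handler has to be registered through static_cast<App*>(&server)->set_request_handle(...), which is the same move the constructor above makes.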

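Handling control requests

When a server receives a command from worker 0, it performs a different action depending on the command type. For example, when the worker sends the StopServer command, the server is stopped. After the command has been executed, the server sends a response back to worker 0. Note that the response here is sent through the ps::SimpleApp interface rather than through ps::KVServer<char>.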

void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
  CommandType recved_type = static_cast<CommandType>(recved.head);
  switch (recved_type) {
    case CommandType::kStopServer:
      exec_.Stop();
      break;
    case CommandType::kSyncMode:
      sync_mode_ = true;
      break;
    case CommandType::kSetGradientCompression:
      gradient_compression_->DecodeParams(recved.body);
      break;
    case CommandType::kSetProfilerParams:
      // last char is the type of profiler command
      ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
                                                (recved.body.back() - '0'),
                                    recved.body);
      break;
    case CommandType::kSetMultiPrecision:
      // uses value 1 for message id from frontend
      if (!multi_precision_) {
        multi_precision_ = true;
        CreateMultiPrecisionCopies();
      }
      break;
    case CommandType::kController:
      // this uses value 0 for message id from frontend
      // let the main thread to execute ctrl, which is necessary for python
      exec_.Exec([this, recved]() {
          CHECK(controller_);
          controller_(recved.head, recved.body);
        });
      break;
  }
  app->Response(recved);
}
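
Reduced to its essentials, the handler decodes an integer head, acts on it, and always acknowledges the sender. The sketch below keeps only two commands; the Command enum values, ServerState and HandleCommand are hypothetical and are not taken from MXNet.

```cpp
#include <cassert>
#include <string>

// Hypothetical subset of command types; numeric values are illustrative.
enum class Command { kStopServer = 0, kSyncMode = 1 };

struct ServerState {
  bool running = true;
  bool sync_mode = false;
  int responses_sent = 0;
};

// Decode the integer head, act on it, then acknowledge the sender.
void HandleCommand(int head, const std::string& body, ServerState* state) {
  assert(state != nullptr);
  (void)body;  // not needed by the two commands in this sketch
  switch (static_cast<Command>(head)) {
    case Command::kStopServer:
      state->running = false;
      break;
    case Command::kSyncMode:
      state->sync_mode = true;
      break;
  }
  ++state->responses_sent;  // stands in for app->Response(recved)
}
```

The acknowledgement sits after the switch, so every command is answered regardless of its type.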

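Handling data requests

As mentioned above, DataHandleEx is registered as the handler for data requests, and it dispatches to a different function depending on the request type. By default it calls DataHandleDefault, which handles push and pull requests from workers separately.

When a worker pushes a gradient to the server and the key is being pushed for the first time, the server allocates memory for that key and stores the pushed value as its initial value. Otherwise the handling depends on sync_mode_. With sync_mode_ == true (synchronous training), the gradients from all workers are aggregated into update_buf_[key].merged; in asynchronous training, the server places the gradient received from a worker in update_buf_[key].temp_array. The metadata of the worker's push request is then recorded in update_buf_[key].request. After that, ApplyUpdates is called to update the model parameters for that key.

When a worker sends a pull request, the server directly calls DefaultStorageResponse to send its latest model parameters to the worker.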

void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
                       const ps::KVPairs<char>& req_data, ps::KVServer<char>* server) {
  int key = DecodeKey(req_data.keys[0]);
  auto& stored = store_[key];
  if (req_meta.push) { // push operation
    size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)};
    mxnet::TShape dshape(ds, ds + 1);
    TBlob recv_blob;
    MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
      recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
    })
    NDArray recved = NDArray(recv_blob, 0);
    if (stored.is_none()) { // the first push request
      // initialization
      stored = NDArray(dshape, Context(), false, type.dtype);
      CopyFromTo(recved, &stored, 0);
      server->Response(req_meta);
      stored.WaitToRead();
    } else {
      auto& updates = update_buf_[key];
      if (sync_mode_ && updates.merged.is_none()) {
        updates.merged = NDArray(dshape, Context(), false, type.dtype);
      }
      if (updates.request.empty()) { // no push buffered yet for this key
        if (sync_mode_) {
          CopyFromTo(recved, updates.merged);
        } else { // async training
          updates.temp_array = recved;
        }
      } else {
        updates.merged += recved;
      }
      updates.request.push_back(req_meta);
      ApplyUpdates(type, key, req_data, &updates, server);
    }
  } else { // pull operation
    DefaultStorageResponse(type, key, req_meta, req_data, server);
  }
}
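
At the start of the push branch, the server turns a byte length into an element count and views the raw char payload as typed data. The following sketch illustrates that conversion for float only. ElementCount, EncodeFloats and DecodeFloats are hypothetical helpers, and the sketch copies with memcpy instead of the zero-copy reinterpret_cast used in the code above.

```cpp
#include <cassert>
#include <cstddef>
#include <cstring>
#include <vector>

// Number of elements carried by a payload of num_bytes bytes.
std::size_t ElementCount(std::size_t num_bytes, std::size_t elem_size) {
  assert(elem_size > 0 && num_bytes % elem_size == 0);
  return num_bytes / elem_size;
}

// Worker side: flatten typed values into the char payload that is sent.
std::vector<char> EncodeFloats(const std::vector<float>& vals) {
  std::vector<char> out(vals.size() * sizeof(float));
  if (!out.empty()) std::memcpy(out.data(), vals.data(), out.size());
  return out;
}

// Server side: recover typed values from the received char payload.
std::vector<float> DecodeFloats(const std::vector<char>& payload) {
  std::vector<float> out(ElementCount(payload.size(), sizeof(float)));
  if (!out.empty()) std::memcpy(out.data(), payload.data(), payload.size());
  return out;
}
```

Sending everything as char lets one KVWorker<char> and KVServer<char> pair carry arrays of any dtype, with the dtype itself travelling in the command field.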

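The ApplyUpdates function implements the core logic of updating the model weights. In asynchronous training mode, or when the number of push requests in update_buf_[key].request equals the number of workers (meaning the server has received the gradients from every worker), the parameter update is performed. Otherwise no update takes place yet: the server only waits for the merge into update_buf_[key].merged to finish, and the buffered requests are answered once the round is complete.

If an updater_ has been set on the server, the update is executed on the server side; otherwise the server only aggregates the gradients. The if (updater_) ... else branch in the code below implements this, and the result of the update or aggregation is stored in store_[key].

The requests saved in update_buf_[key].request can be push or pushpull requests, but never pure pull requests, because req_meta is only added to update_buf_[key].request when req_meta.push == true. Pushpull requests therefore need extra handling. If at least one buffered request has req.pull == true, every such request is answered through DefaultStorageResponse, which sends the model weights to the worker; if none has, server->Response sends each request a response that carries no data, acknowledging that the gradient was received. Once the update is finished, update_buf_[key].request is cleared to wait for the next round.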

inline void ApplyUpdates(const DataHandleType type, const int key,
                         const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
                         ps::KVServer<char>* server) {
  if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
    // let the main thread to execute updater_, which is necessary for python
    auto& stored = store_[key];
    auto& update =  sync_mode_ ? update_buf->merged : update_buf->temp_array;
    if (updater_) { // update_on_kvstore == True
      exec_.Exec([this, key, &update, &stored](){
        CHECK(updater_);
        updater_(key, update, &stored);
      });
    } else { // update_on_kvstore == False, only support for sync mode
      CHECK(sync_mode_) << "Updater needs to be set for async mode";
      // if no updater, just copy
      CopyFromTo(update_buf->merged, &stored);
    }
    /**
     * Request can be for either push or pushpull
     * If pull flag is set, respond immediately with the updated values
     * Otherwise, only send the notification
     */
    bool has_pull = false;
    for (const auto& req : update_buf->request) {
      has_pull = has_pull || req.pull;
    }
    if (has_pull) {
      // if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
      stored.WaitToRead();
      for (const auto& req : update_buf->request) {
        if (req.pull) {
          DefaultStorageResponse(type, key, req, req_data, server);
        }
      }
      update_buf->request.clear();
    } else {
      // otherwise, send response directly
      for (const auto& req : update_buf->request) {
        server->Response(req);
      }
      update_buf->request.clear();
      stored.WaitToRead();
    }
  } else { // do not perform the update yet
    update_buf->merged.WaitToRead();
  }
}
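
The synchronous path can be condensed into a small state machine. The sketch below is an illustration under our own assumptions rather than MXNet code: SyncSlot merges the roles of store_[key] and update_buf_[key], OnPush is a hypothetical function combining the merge step with the update step, and requests are reduced to a counter.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Tensor = std::vector<float>;
using Updater = std::function<void(const Tensor& grad, Tensor* weight)>;

struct SyncSlot {
  Tensor stored;    // plays the role of store_[key]
  Tensor merged;    // plays the role of update_buf_[key].merged
  int pending = 0;  // number of push requests buffered in this round
};

// Returns true when this push completed a round and stored was updated.
bool OnPush(SyncSlot* slot, const Tensor& grad, int num_workers,
            const Updater& updater) {
  assert(grad.size() == slot->stored.size());
  if (slot->pending == 0) {
    slot->merged = grad;  // first push of the round: copy
  } else {
    for (std::size_t i = 0; i < grad.size(); ++i) slot->merged[i] += grad[i];
  }
  ++slot->pending;
  if (slot->pending < num_workers) return false;  // wait for the others
  if (updater) {
    updater(slot->merged, &slot->stored);  // update on the server
  } else {
    slot->stored = slot->merged;  // no updater: keep the aggregate only
  }
  slot->pending = 0;  // ready for the next round
  return true;
}
```

With two workers, the first call only buffers the gradient and the second call triggers the update; passing an empty Updater reproduces the aggregate-only case in which workers pull the summed gradient and update locally.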

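DefaultStorageResponse uses the req_meta and req_data arguments to build a response that carries data for a worker's pull (or pushpull) request. The response is a ps::KVPairs<char> object whose data section is copied from store_[key]. Once the response has been built, server->Response is again called to send it back to the corresponding worker.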

void DefaultStorageResponse(const DataHandleType type,
                            const int key,
                            const ps::KVMeta& req_meta,
                            const ps::KVPairs<char> &req_data,
                            ps::KVServer<char>* server) {
  ps::KVPairs<char> response;
  const NDArray& stored = store_[key];
  CHECK(!stored.is_none()) << "init " << key << " first";

  auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
  response.keys = req_data.keys;
  response.lens = {len};
  // TODO(mli) try to remove this CopyFrom
  response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
  server->Response(req_meta, response);
}
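
The shape of the response can be sketched with standard containers. Response and BuildResponse below are hypothetical stand-ins for ps::KVPairs<char> and DefaultStorageResponse, restricted to float weights: the keys are echoed from the request, lens holds the payload length in bytes, and vals is a byte copy of the stored weights.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Hypothetical stand-in for ps::KVPairs<char>.
struct Response {
  std::vector<std::uint64_t> keys;
  std::vector<int> lens;   // payload length in bytes, one entry per key
  std::vector<char> vals;  // raw bytes of the weights
};

// Build a data-carrying response from the request keys and stored weights.
Response BuildResponse(const std::vector<std::uint64_t>& req_keys,
                       const std::vector<float>& stored) {
  assert(!stored.empty());  // mirrors the "init key first" check
  Response r;
  const std::size_t len = stored.size() * sizeof(float);
  r.keys = req_keys;                 // echo the keys of the request
  r.lens = {static_cast<int>(len)};  // byte count, not element count
  r.vals.resize(len);
  std::memcpy(r.vals.data(), stored.data(), len);  // copy, do not alias
  return r;
}
```

Because vals is a copy, later updates to the stored weights cannot change a response that has already been built, at the cost of one extra memory copy per pull.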