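MXNet Source Code Analysis | KVStore Inter-Process Communication
This analysis is based on MXNet 1.6.0.
The previous article analyzed the intra-process communication mechanism of KVStore in MXNet. This article looks at how KVStore performs distributed communication across multiple nodes.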
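In the KVStore implementation, KVStoreDist and KVStoreDistServer correspond to the worker node and the server node of the parameter server, respectively. KVStoreDist inherits from KVStoreLocal and wraps PS-Lite's KVWorker to implement interfaces such as Push and Pull, through which it sends requests to the servers. KVStoreDistServer wraps PS-Lite's KVServer and handles and responds to the requests sent by workers.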
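Worker-side execution logic
Creating a worker
The constructor of KVStoreDist creates one ps::KVWorker<char> object for each worker node. If the current worker is not a recovery node, it then blocks on a barrier until all workers, all servers, and the scheduler have started.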
explicit KVStoreDist(bool use_device_comm)
    : KVStoreLocal(use_device_comm), ps_worker_(nullptr), server_(nullptr) {
  if (IsWorkerNode()) {
    int new_customer_id = GetNewCustomerId();
    ps_worker_ = new ps::KVWorker<char>(0, new_customer_id);
    ps::StartAsync(new_customer_id, "mxnet\0");
    if (!ps::Postoffice::Get()->is_recovery()) {
      ps::Postoffice::Get()->Barrier(
          new_customer_id,
          ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler);
    }
  }
  bigarray_bound_ = dmlc::GetEnv("MXNET_KVSTORE_BIGARRAY_BOUND", 1000 * 1000);
  log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
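The barrier argument ps::kWorkerGroup + ps::kServerGroup + ps::kScheduler names every node in the job at once. A minimal sketch of why addition works here, assuming the group ids are the distinct powers of two 1, 2 and 4 as in PS-Lite's headers. The constants are redefined locally and InGroup is a hypothetical helper, not part of PS-Lite.

```cpp
#include <cassert>

// Assumed group ids: one bit per role, so groups can be combined by addition.
constexpr int kScheduler = 1;
constexpr int kServerGroup = 2;
constexpr int kWorkerGroup = 4;

// The combined group passed to the barrier in the constructor above.
constexpr int kAllNodes = kWorkerGroup + kServerGroup + kScheduler;

// Hypothetical helper: true if `group` (possibly a combination) includes `role`.
inline bool InGroup(int group, int role) {
  assert(role > 0);
  return (group & role) != 0;
}
```

Under these assumptions a barrier on kAllNodes waits for workers, servers and the scheduler, whereas a barrier on kWorkerGroup alone would only synchronize the workers.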
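Worker initialization
During initialization, each worker first checks that the keys are unique and then calls comm_->Init to set up the local communication resources for every key. Once local initialization is done, worker 0 pushes its local weights to the servers. It then calls WaitToWrite on comm_buf_[key] and compr_buf_[key] for every key, which blocks until the push that uses these buffers has finished. Finally, unless the node is being recovered, all workers synchronize on a barrier.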
void InitImpl(const std::vector<int>& keys,
              const std::vector<NDArray>& values) override {
  CheckUnique(keys);
  for (size_t i = 0; i < keys.size(); ++i) {
    comm_->Init(keys[i], values[i].storage_type(), values[i].shape(), values[i].dtype());
  }
  if (get_rank() == 0 && this->ps_worker_->get_customer()->customer_id() == 0) {
    Push_(keys, values, 0, false);
    // wait until the push is finished
    for (const int key : keys) {
      comm_buf_[key].WaitToWrite();
      compr_buf_[key].WaitToWrite();
    }
  } else {
    // do nothing
  }
  if (!ps::Postoffice::Get()->is_recovery()) {
    Barrier();
  }
}
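The uniqueness check that opens InitImpl can be illustrated with a small standalone function. This is only a sketch of the idea: AllKeysUnique is a hypothetical name, and MXNet's own CheckUnique may be implemented differently.

```cpp
#include <cassert>
#include <unordered_set>
#include <vector>

// Hypothetical helper: returns true when no key appears more than once.
bool AllKeysUnique(const std::vector<int>& keys) {
  std::unordered_set<int> seen;
  for (int key : keys) {
    // insert().second is false when the key was already present.
    if (!seen.insert(key).second) return false;
  }
  return true;
}
```

Rejecting duplicate keys up front matters because initialization creates exactly one communication buffer per key.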
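Sending control messages from a worker
A worker sends control messages to the servers through the SendCommandToServers function. For example, the destructor of KVStoreDist contains the following code, which sends a stop command from worker 0 to all server nodes.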
if (get_rank() == 0 && ps_worker_->get_customer()->customer_id() == 0) {
// stop the executor at servers
SendCommandToServers(static_cast<int>(CommandType::kStopServer), "");
}
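Sending data messages from a worker
A worker calls the Push_ function to send data requests to the servers. Its core logic is shown below (some code is omitted). As with the local communication described earlier, before sending anything to the servers it first calls GroupKVPairsPush to collect the values that share a key into one vector. For each key, when do_merge is true it first performs a local Reduce that aggregates the gradients from all devices, and the result is placed in comm_buf. It then uses EncodeDefaultKey to encode the key and value into the data structure that PS-Lite supports, and calls PushDefault to send the data.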
void KVStoreDist::Push_(const std::vector<int>& keys,
const std::vector<NDArray>& values,
int priority,
bool do_merge) {
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray>> grouped_vals;
GroupKVPairsPush(keys, values, &uniq_keys, &grouped_val, false);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
const auto& vals = grouped_vals[i];
NDArray merged = do_merge ? comm_->Reduce(key, vals, priority) : vals[0];
auto &comm_buf = comm_buf_[key];
if (merged.ctx().dev_mask() == cpu::kDevMask) {
// Start of a push doesn't guarantee that the previous pushes are completed.
// This shouldn't affect training of networks though because training involves
// a sequence of push, pull, then push. This imposes ordering that the
// second push happens after the first pull, and the pull happens after first push.
comm_buf = merged; // avoid memory copy
} else {
if (comm_buf.is_none()) {
comm_buf = NDArray(merged.shape(), pinned_ctx_, true, merged.dtype());
}
CopyFromTo(merged, &comm_buf);
}
const int dtype = merged.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = EncodeDefaultKey(key, comm_buf.shape().Size(), num_bytes);
PushDefault(key, comm_buf, pskv, priority);
}
}
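The group-then-reduce step at the start of Push_ can be sketched without MXNet. The example below is a simplified model under stated assumptions: Tensor is a plain std::vector<float> standing in for NDArray, and GroupByKey and Reduce are hypothetical stand-ins for GroupKVPairsPush and comm_->Reduce, not their real implementations.

```cpp
#include <cassert>
#include <cstddef>
#include <map>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for NDArray

// Groups values that share a key; std::map keeps the keys in ascending order.
std::map<int, std::vector<Tensor>> GroupByKey(const std::vector<int>& keys,
                                              const std::vector<Tensor>& values) {
  assert(keys.size() == values.size());
  std::map<int, std::vector<Tensor>> grouped;
  for (std::size_t i = 0; i < keys.size(); ++i) {
    grouped[keys[i]].push_back(values[i]);
  }
  return grouped;
}

// Element-wise sum of all tensors in one group (the local reduce step).
Tensor Reduce(const std::vector<Tensor>& vals) {
  assert(!vals.empty());
  Tensor merged = vals[0];
  for (std::size_t i = 1; i < vals.size(); ++i) {
    assert(vals[i].size() == merged.size());
    for (std::size_t j = 0; j < merged.size(); ++j) merged[j] += vals[i][j];
  }
  return merged;
}
```

With two devices each contributing a gradient for the same key, only the single merged tensor has to cross the network, which is the point of reducing locally first.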
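PushDefault calls ps_worker_->ZPush to send the gradients. The send, together with the preparation that precedes it, is wrapped in a lambda expression, which is then pushed to the dependency engine in the MXNet backend to await execution.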
void PushDefault(int key, const NDArray &send_buf, const PSKV& pskv, int priority) {
auto push_to_servers =
[this, key, pskv, send_buf](RunContext rctx, Engine::CallbackOnComplete cb) {
const int dtype = send_buf.dtype();
// convert to ps keys
const size_t size = send_buf.shape().Size() * mshadow::mshadow_sizeof(dtype);
char* data = static_cast<char *>(send_buf.data().dptr_);
// do push. false means no delete
ps::SArray<char> vals(data, size, false);
int cmd = GetCommandType(RequestType::kDefaultPushPull, dtype);
CHECK_NOTNULL(ps_worker_)->ZPush(
pskv.keys, vals, pskv.lens,
cmd, [cb]() { cb(); });
};
Engine::Get()->PushAsync(
push_to_servers,
pinned_ctx_,
{send_buf.var()},
{},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultPush");
}
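The pattern used by PushDefault, where the work is wrapped in a lambda that receives a completion callback and is queued rather than run immediately, can be modelled with a toy queue. This is a rough sketch only: ToyEngine and PushToServers are hypothetical, and the real engine additionally tracks read and write dependencies, priorities and execution contexts.

```cpp
#include <cassert>
#include <functional>
#include <queue>
#include <utility>

// Toy stand-in for the dependency engine: operations are queued, then run later.
class ToyEngine {
 public:
  using OnComplete = std::function<void()>;
  using AsyncFn = std::function<void(OnComplete)>;

  // Only enqueues the operation; nothing executes here.
  void PushAsync(AsyncFn fn) { pending_.push(std::move(fn)); }

  // Runs queued operations in FIFO order; each one signals completion itself.
  void RunAll() {
    while (!pending_.empty()) {
      AsyncFn fn = std::move(pending_.front());
      pending_.pop();
      fn([this]() { ++completed_; });
    }
  }

  int completed() const { return completed_; }

 private:
  std::queue<AsyncFn> pending_;
  int completed_ = 0;
};

// Hypothetical send routine: records the payload size, then reports completion,
// the way the ZPush callback invokes cb() once the push has finished.
inline void PushToServers(ToyEngine* engine, int num_bytes, int* bytes_sent) {
  assert(engine != nullptr && bytes_sent != nullptr);
  engine->PushAsync([num_bytes, bytes_sent](ToyEngine::OnComplete cb) {
    *bytes_sent += num_bytes;  // stands in for the actual network send
    cb();
  });
}
```

Calling PushToServers returns immediately, and the send only happens when the engine runs the queued operation, which is what lets communication overlap with computation.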
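The pull operation proceeds as follows. Once the preparation is done, ps_worker_->ZPull is called to pull the weights, and finally a local Broadcast copies the weights pulled from the servers to all devices.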
void PullImpl(const std::vector<int>& keys,
const std::vector<NDArray*>& values,
int priority, bool ignore_sparse) override {
CHECK(ignore_sparse) << "dist kvstore pull doesn't support ignore_sparse=False";
std::vector<int> uniq_keys;
std::vector<std::vector<NDArray*> > grouped_vals;
GroupKVPairsPull(keys, values, &uniq_keys, &grouped_vals, true);
for (size_t i = 0; i < uniq_keys.size(); ++i) {
int key = uniq_keys[i];
// use the same array for merging to guarantee that pull always happens
// after the previous push on this key
auto& recv_buf = comm_buf_[key];
const auto storage_type = grouped_vals[i][0]->storage_type();
CHECK_EQ(storage_type, kDefaultStorage)
<< "Expected stype of value to be kDefaultStorage";
if (recv_buf.is_none()) {
// it may happen for the first time a no-rank-0 worker pull the weight.
recv_buf = NDArray(grouped_vals[i][0]->shape(), pinned_ctx_,
true, grouped_vals[i][0]->dtype());
}
auto pull_from_servers = [this, key, recv_buf](
RunContext rctx, Engine::CallbackOnComplete cb) {
// convert to ps keys
size_t size = recv_buf.shape().Size();
const int dtype = recv_buf.dtype();
const int num_bytes = mshadow::mshadow_sizeof(dtype);
PSKV& pskv = EncodeDefaultKey(key, size, num_bytes);
char* data = static_cast<char*> (recv_buf.data().dptr_);
// false means not to delete data when SArray is deleted
auto vals = new ps::SArray<char>(data, size * num_bytes, false);
// issue pull
RequestType mode = RequestType::kDefaultPushPull;
const int cmd = GetCommandType(mode, dtype);
CHECK_NOTNULL(ps_worker_)->ZPull(
pskv.keys, vals, &pskv.lens, cmd, [vals, cb](){ delete vals; cb(); });
};
CHECK_NOTNULL(Engine::Get())->PushAsync(
pull_from_servers,
pinned_ctx_,
{},
{recv_buf.var()},
FnProperty::kNormal,
priority,
"KVStoreDistDefaultStoragePull");
comm_->Broadcast(key, recv_buf, grouped_vals[i], priority);
}
}
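The two halves of the pull, filling the receive buffer in place and then broadcasting it to every device, can be sketched as follows. The sketch assumes float data and uses plain vectors: PullInto and Broadcast are hypothetical stand-ins for the ZPull call and comm_->Broadcast.

```cpp
#include <cassert>
#include <cstring>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for NDArray

// Simulated pull: writes the server's bytes straight into the receive buffer,
// mirroring how the pull fills recv_buf's memory without owning it.
void PullInto(const Tensor& server_copy, Tensor* recv_buf) {
  assert(recv_buf != nullptr && recv_buf->size() == server_copy.size());
  char* data = reinterpret_cast<char*>(recv_buf->data());
  std::memcpy(data, server_copy.data(), server_copy.size() * sizeof(float));
}

// Copies the pulled buffer into every per-device destination.
void Broadcast(const Tensor& recv_buf, const std::vector<Tensor*>& dsts) {
  for (Tensor* dst : dsts) {
    assert(dst != nullptr);
    *dst = recv_buf;
  }
}
```

Note that the byte count is the element count times the element size, which matches the size * num_bytes expression in the listing above.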
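Server-side execution logic
Creating and starting a server
The constructor of KVStoreDistServer first binds two handlers to ps_server_: CommandHandle for command requests and DataHandleEx for data requests. Note that ps_server_ is upcast to ps::SimpleApp* when CommandHandle is bound. This is because the set_request_handle implemented in ps::SimpleApp accepts only a function object with two parameters, whereas ps::KVServer, which inherits from ps::SimpleApp, declares its own set_request_handle that accepts a function object with three parameters and hides the base-class version. The upcast selects the two-parameter version, so control requests and data requests end up being handled separately.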
KVStoreDistServer() {
using namespace std::placeholders;
ps_server_ = new ps::KVServer<char>(0);
static_cast<ps::SimpleApp*>(ps_server_)->set_request_handle(
std::bind(&KVStoreDistServer::CommandHandle, this, _1, _2));
ps_server_->set_request_handle(
std::bind(&KVStoreDistServer::DataHandleEx, this, _1, _2, _3));
sync_mode_ = false;
gradient_compression_ = std::make_shared<GradientCompression>();
log_verbose_ = dmlc::GetEnv("MXNET_KVSTORE_DIST_ROW_SPARSE_VERBOSE", false);
}
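The need for the static_cast follows from C++ name lookup: a member function declared in a derived class hides same-named members of the base class. The sketch below reproduces the situation with toy classes. ToySimpleApp and ToyKVServer are hypothetical and only mimic the shape of the PS-Lite classes.

```cpp
#include <cassert>
#include <functional>
#include <string>

// Toy base class: accepts a two-parameter handler for control messages.
class ToySimpleApp {
 public:
  using Handle = std::function<void(const std::string&, ToySimpleApp*)>;
  void set_request_handle(const Handle& handle) { cmd_handle_ = handle; }
  void ReceiveCommand(const std::string& body) {
    assert(cmd_handle_);
    cmd_handle_(body, this);
  }

 private:
  Handle cmd_handle_;
};

// Toy derived class: declares a same-named setter taking a three-parameter handler.
// This declaration hides the base-class setter for calls made on ToyKVServer.
class ToyKVServer : public ToySimpleApp {
 public:
  using ReqHandle = std::function<void(int, const std::string&, ToyKVServer*)>;
  void set_request_handle(const ReqHandle& handle) { data_handle_ = handle; }
  void ReceiveData(int key, const std::string& vals) {
    assert(data_handle_);
    data_handle_(key, vals, this);
  }

 private:
  ReqHandle data_handle_;
};
```

Calling set_request_handle directly on a ToyKVServer with a two-parameter callable would not compile, while static_cast<ToySimpleApp*>(&server)->set_request_handle(...) reaches the base-class setter, so both handlers can be registered on the same object.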
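Handling control requests
When a server receives a command from worker 0, it performs a different action depending on the command type. For example, when the worker sends the StopServer command, the server is stopped. After the command has been executed, the server sends a response back to worker 0. Note that the response is sent through the ps::SimpleApp object passed to the handler, not through a ps::KVServer<char> object.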
void CommandHandle(const ps::SimpleData& recved, ps::SimpleApp* app) {
CommandType recved_type = static_cast<CommandType>(recved.head);
switch (recved_type) {
case CommandType::kStopServer:
exec_.Stop();
break;
case CommandType::kSyncMode:
sync_mode_ = true;
break;
case CommandType::kSetGradientCompression:
gradient_compression_->DecodeParams(recved.body);
break;
case CommandType::kSetProfilerParams:
// last char is the type of profiler command
ProcessServerProfilerCommands(static_cast<KVStoreServerProfilerCommand>
(recved.body.back() - '0'),
recved.body);
break;
case CommandType::kSetMultiPrecision:
// uses value 1 for message id from frontend
if (!multi_precision_) {
multi_precision_ = true;
CreateMultiPrecisionCopies();
}
break;
case CommandType::kController:
// this uses value 0 for message id from frontend
// let the main thread to execute ctrl, which is necessary for python
exec_.Exec([this, recved]() {
CHECK(controller_);
controller_(recved.head, recved.body);
});
break;
}
app->Response(recved);
}
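The dispatch-then-acknowledge structure of CommandHandle can be reduced to a few lines. The following is a simplified model: ToyCommand covers only a hypothetical subset of the commands, and a counter stands in for app->Response(recved).

```cpp
#include <cassert>

// Hypothetical subset of command types, for illustration only.
enum class ToyCommand { kStopServer, kSyncMode, kSetMultiPrecision };

struct ToyServerState {
  bool running = true;
  bool sync_mode = false;
  bool multi_precision = false;
  int responses_sent = 0;
};

// Applies one command to the server state, then acknowledges it.
void HandleCommand(int head, ToyServerState* state) {
  assert(state != nullptr);
  switch (static_cast<ToyCommand>(head)) {
    case ToyCommand::kStopServer:
      state->running = false;
      break;
    case ToyCommand::kSyncMode:
      state->sync_mode = true;
      break;
    case ToyCommand::kSetMultiPrecision:
      state->multi_precision = true;
      break;
  }
  ++state->responses_sent;  // stands in for app->Response(recved)
}
```

As in the original handler, the acknowledgement sits after the switch, so every command is answered regardless of its type.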
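Handling data requests
As mentioned above, DataHandleEx is registered as the handler for data requests, and it dispatches to different handler functions according to the request type. By default it calls DataHandleDefault, which handles push and pull requests from workers separately. When a worker pushes gradients to the server and the key is being pushed for the first time, the server allocates memory for that key, stores the pushed value, and responds immediately; otherwise the handling depends on the value of sync_mode_. When sync_mode_ == true (synchronous training), the gradients from all workers are aggregated into update_buf_[key].merged; in asynchronous training, the server places the gradient received from a worker in update_buf_[key].temp_array. The push request from the worker is then recorded in update_buf_[key].request. Once this is done, ApplyUpdates is called to update the model parameters for the key. When a worker sends a pull request, the server directly calls DefaultStorageResponse to send its latest model parameters to the worker.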
void DataHandleDefault(const DataHandleType type, const ps::KVMeta& req_meta,
const ps::KVPairs<char>& req_data, ps::KVServer<char>* server) {
int key = DecodeKey(req_data.keys[0]);
auto& stored = store_[key];
if (req_meta.push) { // push operation
size_t ds[] = {(size_t) req_data.lens[0] / mshadow::mshadow_sizeof(type.dtype)};
mxnet::TShape dshape(ds, ds + 1);
TBlob recv_blob;
MSHADOW_REAL_TYPE_SWITCH(type.dtype, DType, {
recv_blob = TBlob(reinterpret_cast<DType*>(req_data.vals.data()), dshape, cpu::kDevMask);
})
NDArray recved = NDArray(recv_blob, 0);
if (stored.is_none()) { // the first push request
// initialization
stored = NDArray(dshape, Context(), false, type.dtype);
CopyFromTo(recved, &stored, 0);
server->Response(req_meta);
stored.WaitToRead();
} else {
auto& updates = update_buf_[key];
if (sync_mode_ && updates.merged.is_none()) {
updates.merged = NDArray(dshape, Context(), false, type.dtype);
}
if (updates.request.empty()) { // the first
if (sync_mode_) {
CopyFromTo(recved, updates.merged);
} else { // async training
updates.temp_array = recved;
}
} else {
updates.merged += recved;
}
updates.request.push_back(req_meta);
ApplyUpdates(type, key, req_data, &updates, server);
}
} else { // pull operation
DefaultStorageResponse(type, key, req_meta, req_data, server);
}
}
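How the server accumulates pushes for one key can be shown with a stripped-down buffer. This is a sketch under simplifying assumptions: ToyUpdateBuf stores worker ids instead of request metadata, Tensor is a plain vector, and RecordPush is a hypothetical name covering only the branch taken after the key has been initialized.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for NDArray

struct ToyUpdateBuf {
  std::vector<int> request;  // ids of workers whose push is pending
  Tensor merged;             // running sum in synchronous mode
  Tensor temp_array;         // latest gradient in asynchronous mode
};

void RecordPush(bool sync_mode, int worker_id, const Tensor& grad, ToyUpdateBuf* buf) {
  assert(buf != nullptr);
  if (buf->request.empty()) {
    // First push of this round: start from the received gradient.
    if (sync_mode) {
      buf->merged = grad;
    } else {
      buf->temp_array = grad;
    }
  } else {
    // Later pushes of the same round are accumulated (synchronous mode only).
    assert(sync_mode && buf->merged.size() == grad.size());
    for (std::size_t i = 0; i < grad.size(); ++i) buf->merged[i] += grad[i];
  }
  buf->request.push_back(worker_id);
}
```

In asynchronous mode the request list is cleared after every push, so only the first branch is ever reached and temp_array always holds a single worker's gradient.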
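The ApplyUpdates function implements the core logic of updating the model weights. In asynchronous training mode, or when the number of push requests in the current update_buf equals the number of workers (meaning the server has received the gradients from every worker), the parameter update is performed. Otherwise no update takes place: the server only waits for the merge into update_buf->merged to finish and leaves the recorded requests unanswered until the remaining gradients arrive. If an updater updater_ has been set on the server, the update is executed on the server; otherwise the server only aggregates the gradients. Lines 7 to 16 of the code below describe this step, and the updated or aggregated result is stored in store_[key]. The requests kept in update_buf_[key].request may be either push or pushpull requests (never plain pull requests, because req_meta is added to update_buf_[key].request only when req_meta.push == true), so pushpull requests need extra handling. If any recorded request has req.pull == true, DefaultStorageResponse is called for each such request to send the model weights to the worker; if none does, every request receives a response without data through server->Response, acknowledging the push. Once the update has finished, update_buf_[key].request is cleared to await the next round.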
inline void ApplyUpdates(const DataHandleType type, const int key,
const ps::KVPairs<char>& req_data, UpdateBuf *update_buf,
ps::KVServer<char>* server) {
if (!sync_mode_ || update_buf->request.size() == (size_t) ps::NumWorkers()) {
// let the main thread to execute updater_, which is necessary for python
auto& stored = store_[key];
auto& update = sync_mode_ ? update_buf->merged : update_buf->temp_array;
if (updater_) { // update_on_kvstore == True
exec_.Exec([this, key, &update, &stored](){
CHECK(updater_);
updater_(key, update, &stored);
});
} else { // update_on_kvstore == False, only support for sync mode
CHECK(sync_mode_) << "Updater needs to be set for async mode";
// if no updater, just copy
CopyFromTo(update_buf->merged, &stored);
}
/**
* Request can be for either push or pushpull
* If pull flag is set, respond immediately with the updated values
* Otherwise, only send the notification
*/
bool has_pull = false;
for (const auto& req : update_buf->request) {
has_pull = has_pull || req.pull;
}
if (has_pull) {
// if there is a pull request, perform WaitToRead() once before DefaultStorageResponse
stored.WaitToRead();
for (const auto& req : update_buf->request) {
if (req.pull) {
DefaultStorageResponse(type, key, req, req_data, server);
}
}
update_buf->request.clear();
} else {
// otherwise, send response directly
for (const auto& req : update_buf->request) {
server->Response(req);
}
update_buf->request.clear();
stored.WaitToRead();
}
} else { // donot perform update operation
update_buf->merged.WaitToRead();
}
}
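The decision made at the top of ApplyUpdates, whether to update now and whether to run an updater or just copy, can be isolated in a small function. The sketch below is an illustration only: MaybeApply and the Updater signature are hypothetical, and responses, multi-precision copies and the executor thread are left out.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

using Tensor = std::vector<float>;  // stand-in for NDArray
// Hypothetical updater signature: (gradient, weight to modify in place).
using Updater = std::function<void(const Tensor&, Tensor*)>;

// Returns true when the stored weight was refreshed in this call.
bool MaybeApply(bool sync_mode, std::size_t num_workers, std::size_t pending_requests,
                const Tensor& update, const Updater& updater, Tensor* stored) {
  assert(stored != nullptr);
  if (sync_mode && pending_requests != num_workers) {
    return false;  // still waiting for gradients from other workers
  }
  if (updater) {
    updater(update, stored);  // e.g. an optimizer step executed on the server
  } else {
    assert(sync_mode);  // without an updater only aggregation is possible
    *stored = update;
  }
  return true;
}
```

With two workers in synchronous mode, the first push leaves the weight untouched and the second one triggers the update, which is the behaviour that keeps all workers training on the same parameters.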
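DefaultStorageResponse uses the req_meta and req_data arguments to build a response that carries data, answering a worker's pull or pushpull request. The response is a ps::KVPairs<char> object whose data is copied from store_[key]. After the response object has been built, server->Response is again called to send the message back to the corresponding worker.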
void DefaultStorageResponse(const DataHandleType type,
const int key,
const ps::KVMeta& req_meta,
const ps::KVPairs<char> &req_data,
ps::KVServer<char>* server) {
ps::KVPairs<char> response;
const NDArray& stored = store_[key];
CHECK(!stored.is_none()) << "init " << key << " first";
auto len = stored.shape().Size() * mshadow::mshadow_sizeof(stored.dtype());
response.keys = req_data.keys;
response.lens = {len};
// TODO(mli) try to remove this CopyFrom
response.vals.CopyFrom(static_cast<const char*>(stored.data().dptr_), len);
server->Response(req_meta, response);
}
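Building the response amounts to echoing the request keys, recording the payload length in bytes, and copying the stored weights into a byte array. A minimal sketch of that step, assuming float weights: ToyKVPairs is a hypothetical stand-in for ps::KVPairs<char> built from standard containers, and BuildResponse is not an MXNet function.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
#include <vector>

// Toy stand-in for ps::KVPairs<char>.
struct ToyKVPairs {
  std::vector<std::uint64_t> keys;
  std::vector<int> lens;
  std::vector<char> vals;
};

// Builds a data-carrying response from the stored weights of one key.
ToyKVPairs BuildResponse(const std::vector<std::uint64_t>& req_keys,
                         const std::vector<float>& stored) {
  assert(!stored.empty());  // the key must have been initialized first
  ToyKVPairs response;
  const std::size_t len = stored.size() * sizeof(float);  // length in bytes
  response.keys = req_keys;
  response.lens = {static_cast<int>(len)};
  response.vals.resize(len);
  std::memcpy(response.vals.data(), stored.data(), len);
  return response;
}
```

The copy gives the response its own bytes, so later updates to the stored weights cannot change a message that is still being sent.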