MXNet原始碼分析 | Gluon介面分散式訓練流程
阿新 • • 發佈:2020-09-21
既然我們已經知道kvstore負責分散式訓練過程中引數的同步,那麼它究竟是如何應用在訓練中的呢?下面我們將從Gluon.Trainer
這個介面入手,逐步分析分散式訓練的梯度交換以及引數同步過程。下面這段程式碼摘自python/mxnet/gluno/trainer.py檔案,相較於原始碼刪除了一些多餘的資訊(如某些判斷、註釋等),以便讓我們更好地專注於通訊過程。
程式碼中的step
函式是進行梯度交換以及引數更新的主體,它首先呼叫_init_kvstore
去初始化kvstore,然後呼叫_allreduce_grads
進行梯度傳輸,最後呼叫_update
實現引數更新。
class Trainer(object): def step(self, batch_size, ignore_stale_grad=False): if not self._kv_initialized: self._init_kvstore() if self._params_to_init: self._init_params() self._allreduce_grads() self._update(ignore_stale_grad)
首先,_init_kvstore
這個函式會通過使用者指定的引數來呼叫model.py中的_create_kvstore
來初始化kvstore
以及update_kv_store
這兩個變數。其中kvstore
是KVStore
類的一個例項化物件,而update_on_kvstore
是一個布林型變數,用來判斷是否在ps端更新引數。換句話說,如果該變數為True,那麼模型引數的更新發生在ps端;否則,模型引數的更新發生在worker端,ps端只做梯度的聚合操作(這種情況下,paramerter server是不是就變成了gradient server?