1. 程式人生 > 實用技巧 >MXNet原始碼分析 | Gluon介面分散式訓練流程

MXNet原始碼分析 | Gluon介面分散式訓練流程

既然我們已經知道kvstore負責分散式訓練過程中引數的同步,那麼它究竟是如何應用在訓練中的呢?下面我們將從Gluon.Trainer這個介面入手,逐步分析分散式訓練的梯度交換以及引數同步過程。下面這段程式碼摘自python/mxnet/gluno/trainer.py檔案,相較於原始碼刪除了一些多餘的資訊(如某些判斷、註釋等),以便讓我們更好地專注於通訊過程。

程式碼中的step函式是進行梯度交換以及引數更新的主體,它首先呼叫_init_kvstore去初始化kvstore,然後呼叫_allreduce_grads進行梯度傳輸,最後呼叫_update實現引數更新。

class Trainer(object):
    def step(self, batch_size, ignore_stale_grad=False):
        if not self._kv_initialized:
            self._init_kvstore()
        if self._params_to_init:
            self._init_params()

        self._allreduce_grads()
        self._update(ignore_stale_grad)

首先,_init_kvstore這個函式會通過使用者指定的引數來呼叫model.py中的_create_kvstore來初始化kvstore以及update_kv_store這兩個變數。其中kvstoreKVStore類的一個例項化物件,而update_on_kvstore是一個布林型變數,用來判斷是否在ps端更新引數。換句話說,如果該變數為True,那麼模型引數的更新發生在ps端;否則,模型引數的更新發生在worker端,ps端只做梯度的聚合操作(這種情況下,paramerter server是不是就變成了gradient server?