[原始碼解析] 深度學習分散式訓練框架 horovod (17) --- 彈性訓練之容錯
[原始碼解析] 深度學習分散式訓練框架 horovod (17) --- 彈性訓練之容錯
目錄0x00 摘要
Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。
本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是系列第十七篇,看看horovod 的容錯機制。
我們依然用問題來引導學習。
問題是:
- 這些異常是 每個 worker 自動發出的嗎?
- 是 worker 們一起丟擲異常嗎?
- 這些異常怎麼通知給 Driver?
我們下面一一分析(為了可以獨立成文,本文部分原理內容與前文相同)。
本系列其他文章連結如下:
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識
[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入
[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼
[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架
[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構
[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer
[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (11) --- on spark --- GLOO 方案
[原始碼解析] 深度學習分散式訓練框架 horovod (12) --- 彈性訓練總體架構
[原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State
[原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知
[原始碼解析] 深度學習分散式訓練框架 horovod (16) --- 彈性訓練之Worker生命週期
0x01總體思路
首先,我們需要注意的是:在某種程度上,容錯和彈性排程互為因果。
- 容錯的意思是,作業不受其中程序數量變化影響。
- 彈性排程時,作業裡的程序數量會隨叢集 workload 情況增減,所以作業必須是容錯的,才能和排程系統配合,實現彈性排程。
其次,在原始碼的文件之中,有如下注釋,我們可以看到容錯具體思路。
The reset process following a ``HorovodInternalError`` (failure) or ``HostsUpdatedInterrupt`` (add/remove request) is as follows:
1. Catch exception within the ``hvd.elastic.run`` decorator.
2. Restore last committed state if ``HorovodInternalError`` was raised.
3. Reinitialize Horovod context performing a new round of rendezvous.
4. Synchronize state among the workers by broadcasting from the new worker-0.
5. Resume training by executing the underlying training function.
During rendezvous, older workers will take priority in being assigned worker-0 status to ensure that the state that is broadcast is up to date.
大致翻譯如下:
對於出錯狀態下,在worker程序出現 HorvodInternalError 錯誤或者 HostsUpdateInterrupt 節點增刪時,會捕獲這兩個錯誤,呼叫 reset 來進行容錯處理:
- 在
hvd.elastic.run
裝飾器捕獲異常; - 如果是
HorovodInternalError
,就恢復到最近一次提交(commit)的狀態; - 重新初始化 Horovod context,然後driver 會根據當前正在執行的節點觸發新的一輪的rendezvous,在rendezvous過程中,舊的worker會被優先被選舉為新的rank-0,因為舊的worker才具有最新的狀態;
- 當新的通訊域構造成功後,rank=0 的 worker 會將自身的模型(狀態)廣播給其他 worker;
- 接著上次停止的迭代步數開始訓練,繼續跑下訓練函式(train)中的程式碼;
我們具體來看看如何處理。
0x02 丟擲異常
2.1 示例程式碼
我們首先回顧下用示例程式碼。
import tensorflow as tf
import horovod.tensorflow as hvd
hvd.init()
@tf.function
def train_one_batch(data, target, allreduce=True):
with tf.GradientTape() as tape:
probs = model(data, training=True)
loss = tf.losses.categorical_crossentropy(target, probs)
if allreduce:
tape = hvd.DistributedGradientTape(tape)
gradients = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(gradients, model.trainable_variables))
....
@hvd.elastic.run # 這裡進行了包裝,所以才能進行彈性訓練
def train(state):
for state.epoch in range(state.epoch, epochs):
for state.batch in range(state.batch, batches_per_epoch):
data, target = get_random_batch()
train_one_batch(data, target)
if state.batch % batches_per_commit == 0:
state.commit()
state.batch = 0
state = hvd.elastic.TensorFlowKerasState(model, optimizer, batch=0, epoch=0)
state.register_reset_callbacks([on_state_reset])
train(state)
最關鍵的就是用介面卡 @hvd.elastic.run 包裝了 train(state),所以我們順著來看。
2.2 HorovodInternalError
從如下程式碼可知 hvd.elastic.run 就是 horovod/tensorflow/elastic.py 之中的 run 函式。
import horovod.tensorflow as hvd
@hvd.elastic.run
因此我們來到了 horovod/tensorflow/elastic.py。
func 就是使用者訓練函式,當執行使用者訓練函數出錯時候,會根據捕獲的異常資訊來進行分析,如果是 ring allreduce 相關,就轉為丟擲異常 HorovodInternalError(e)。
def run(func):
from tensorflow.python.framework.errors_impl import UnknownError
def wrapper(state, *args, **kwargs):
try:
return func(state, *args, **kwargs)
except UnknownError as e:
# 判斷是否是集合通訊相關
if 'HorovodAllreduce' in e.message or \
'HorovodAllgather' in e.message or \
'HorovodBroadcast' in e.message:
raise HorovodInternalError(e)
return run_fn(wrapper, _reset)
2.3 HostsUpdatedInterrupt
從前文我們知道:
當驅動程序通過節點發現指令碼發現一個節點被標記為新增或者移除時,它將傳送一個通知到所有workers,worker 根據通知來進行處理。
具體如下:
- 驅動(後臺發現)程序 獲取 WorkerNotificationClient,然後呼叫 WorkerNotificationClient 來進行通知。就是利用 WorkerNotificationClient 傳送 HostsUpdatedRequest。
- WorkerNotificationService 繼承了 network.BasicService,所以 WorkerNotificationClient 就是作為 WorkerNotificationService 的操作介面,從而給 WorkerNotificationService 傳送 HostsUpdatedRequest。
- WorkerNotificationService 會響應 HostsUpdatedRequest。呼叫 handle_hosts_updated 會逐一通知註冊在WorkerNotificationManager 上的 listener(就是使用者程式碼中的 State)。
- 每一個 worker 有自己對應的 State,都位於
WorkerNotificationManager . _listeners
。 - 每個worker收到通知之後,呼叫 _host_messages 會在state 之中註冊 host 的變化,就是往其 _host_messages 之中放入"host 有變化" 的訊息。
- 在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,state.check_host_updates 會從 _host_messages 中讀取訊息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲 HostsUpdateInterrupt 異常。具體同步使用 _bcast_object(然後內部呼叫到了 MPI)。
- state.check_host_updates() 會丟擲 HostsUpdateInterrupt 異常。
具體程式碼如下:
在使用者呼叫 commit 的時候,會呼叫 check_host_updates 檢查更新。這裡對使用者程式碼是侵入了,使用者使用到了框架的東西,雖然不知道 Driver,但是用到了框架的其他東西,比如 state。
def commit(self):
self.save()
self.check_host_updates()
檢查更新如下。
如果發現 host 有變化,就會產生一個 HostsUpdatedInterrupt 異常。
def check_host_updates(self):
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
# 丟擲異常
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
2.4 總結
因此我們可以回答文初的兩個問題:
- 這些異常是 每個 worker 自動發出的嗎?
- 是的自動丟擲的。
- 當執行 使用者訓練函數出錯時候,會根據捕獲的異常資訊來進行分析,如果是 ring allreduce 相關,就轉為丟擲異常 HorovodInternalError(e)。
- 當如果發現 host 有變化,就會產生一個 HostsUpdatedInterrupt 異常。
- 是 worker 們一起丟擲異常嗎?
- 是一起丟擲。
- 如果訓練出錯,則都會丟擲異常
- 當驅動程序通過節點發現指令碼發現一個節點被標記為新增或者移除時,它將傳送一個通知到 所有workers,在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,會一起丟擲一個 HostsUpdateInterrupt 異常。
丟擲異常的邏輯如下:
+-----------------------------------------------------------------+
| Worker |
| |
| HostsUpdatedInterrupt HorovodInternalError |
| ^ ^ |
| | | |
| | +----------------------------------+ | |
| | | train | | |
| | | | | |
| | | optimizer.apply_gradients +---------+ |
| | | | |
| +-------+ state.commit() |
| | | |
| +----------------------------------+ |
| |
| |
+-----------------------------------------------------------------+
0x03 處理異常
3.1 總體邏輯
總體架構是 在 run_fn 之中。
回憶一下 run_fn 是從哪裡來呼叫的。原來是在 run 之中,就是執行 wrapper。而 wrapper 本身是對使用者訓練函式的包裝。
def run(func):
from tensorflow.python.framework.errors_impl import UnknownError
def wrapper(state, *args, **kwargs):
try:
return func(state, *args, **kwargs)
except UnknownError as e:
if 'HorovodAllreduce' in e.message or \
'HorovodAllgather' in e.message or \
'HorovodBroadcast' in e.message:
raise HorovodInternalError(e)
return run_fn(wrapper, _reset)
大概邏輯如圖:
+----------------------------------------------------------------------------+
| Worker |
| |
| +----------------------------------------------------------------------+ |
| | run_fn | |
| | | |
| | | |
| | | |
| | | |
| | | |
| | | |
| | | |
| | HostsUpdatedInterrupt HorovodInternalError | |
| | ^ ^ | |
| | | | | |
| | | +----------------------------------+ | | |
| | | | train | | | |
| | | | | | | |
| | | | optimizer.apply_gradients +---------+ | |
| | | | | | |
| | +-------+ state.commit() | |
| | | | | |
| | +----------------------------------+ | |
| | | |
| | | |
| | | |
| +----------------------------------------------------------------------+ |
+----------------------------------------------------------------------------+
run_fn邏輯如下:
- 當 HorovodInternalError 產生,就會呼叫 state.restore() 來恢復;
- 當 HostsUpdatedInterrupt 被捕獲,會設定 skip_sync;
- 呼叫 reset(),state.on_reset() 進行重置;
- 當下次迴圈,會根據 skip_sync 決定是否執行 state.sync();
具體程式碼如下:
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync()
try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore()
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync
reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper
所以我們拓展邏輯如下:
+------------------------------------------------------------------------------+
| Worker |
| |
| +------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError | | |
| | + | | |
| | + | | | |
| | | | | | |
| | | v | | |
| | | state.restore() | | |
| | | + | | |
| | | | | | |
| | +------------------+ <------------------+ | | |
| | | | | | |
| | | | | | |
| | v v | | |
| | reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +-----------------------------------> | |
| | | |
| +------------------------------------------------------------------------+ |
| |
+------------------------------------------------------------------------------+
3.2 恢復
state.restore() 會進行恢復。
在 TensorFlowKerasState 之中,實現了 restore。
def restore(self):
self._load_model()
super(TensorFlowKerasState, self).restore()
具體 restore 就是重新載入模型,具體載入就是利用 TensorFlowKerasState 的 model, optimizer 這兩個成員變數。
def _load_model(self):
if _executing_eagerly():
for var, saved_var in zip(self.model.variables, self._saved_model_state):
var.assign(saved_var)
for var, saved_var in zip(self.optimizer.variables(), self._saved_optimizer_state):
var.assign(saved_var)
else:
self.model.set_weights(self._saved_model_state)
self.optimizer.set_weights(self._saved_optimizer_state)
我們拓展如下:
+---------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +----------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +---------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError | | |
| | + | | |
| | + | | | |
| | | | | | |
| | | v +-------------------------+ | | |
| | | state.restore() +---> | _load_model | | | |
| | | + | | | | |
| | | | | model.set_weights | | | |
| | +------------------+ <------------------+ | optimizer.set_weights | | | |
| | | | | var.assign(saved_var) | | | |
| | | | | | | | |
| | v v +-------------------------+ | | |
| | reset() | | |
| | | | |
| | state.on_reset() | | |
| | | | |
| | + | | |
| | | | | |
| | +----------------------------------------------------------------^ | |
| | | |
| +----------------------------------------------------------------------------------------------------+ |
| |
+---------------------------------------------------------------------------------------------------------+
手機如下:
3.3 重置
以下程式碼會進行重置操作。
reset()
state.on_reset()
3.3.1 reset
具體 reset 函式是:
def _reset():
shutdown()
init()
3.3.2 _HorovodBasics
具體使用了 _HorovodBasics 這裡的函式。
_basics = _HorovodBasics(__file__, 'mpi_lib')
init = _basics.init
shutdown = _basics.shutdown
具體如下,就是重新建立 MPI 相關 context。
def init(self, comm=None):
if comm is None:
comm = []
atexit.register(self.shutdown)
if not isinstance(comm, list):
mpi_built = self.MPI_LIB_CTYPES.horovod_mpi_built()
from mpi4py import MPI
if MPI._sizeof(MPI.Comm) == ctypes.sizeof(ctypes.c_int):
MPI_Comm = ctypes.c_int
else:
MPI_Comm = ctypes.c_void_p
self.MPI_LIB_CTYPES.horovod_init_comm.argtypes = [MPI_Comm]
comm_obj = MPI_Comm.from_address(MPI._addressof(comm))
self.MPI_LIB_CTYPES.horovod_init_comm(comm_obj)
else:
comm_size = len(comm)
self.MPI_LIB_CTYPES.horovod_init(
(ctypes.c_int * comm_size)(*comm), ctypes.c_int(comm_size))
def shutdown(self):
self.MPI_LIB_CTYPES.horovod_shutdown()
3.3.3 on_reset
是執行使用者設定的 reset callback。
def on_reset(self):
self._host_messages = queue.Queue()
self.reset()
for callback in self._reset_callbacks:
callback()
比如使用者設定如下callback:
def on_state_reset():
optimizer.lr.assign(lr * hvd.size())
此時邏輯如下:
+-------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +--------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +-----------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | | | |
| | state.sync() | | |
| | + | | |
| | | | | |
| | | | | |
| | v | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError +-------------------------+ | | |
| | + | _load_model | | | |
| | + | | | | | |
| | | | | model.set_weights | | | |
| | | v | optimizer.set_weights | | | |
| | | state.restore() +---> | var.assign(saved_var) | | | |
| | | + | | | | |
| | | | +-------------------------+ | | |
| | +------------------+ <------------------+ | | |
| | | | +-------------------------+ | | |
| | | | | _HorovodBasics | | | |
| | v v | | | | |
| | reset() +-----------------------------> | | | | |
| | +---------------+ | horovod_init | | | |
| | | user callback +<----+ state.on_reset() | | | | |
| | +---------------+ | horovod_init_comm | | | |
| | + | | | | |
| | | +-------------------------+ | | |
| | +------------------------------------------------------------------^ | |
| | | |
| +--------------------------------------------------------------------------------------------------------+ |
| |
+-------------------------------------------------------------------------------------------------------------+
手機如下:
3.3.4 sync
當重置時候,使用者也會進行必要的同步,具體是廣播變數 和 存模型 兩步。
def sync(self):
if self.session is not None:
self.session.run(self._bcast_op)
self._save_model()
super(TensorFlowState, self).sync()
3.3.4.1 廣播
廣播函式在之前初始化時候有設定
self._bcast_op = broadcast_variables(self.variables, root_rank=0)
因此,就是 當新的通訊域構造成功後,rank=0 的 worker 會將自身的模型廣播給其他 worker。
3.3.4.2 存模型
存模型就是呼叫 _eval_fn 來把模型變數轉存到記憶體之中。
def _save_model(self):
self._values = [self._eval_fn(var) for var in self.variables]
_eval_fn 在 之前初始化時候有設定
self._eval_fn = self._to_numpy if _executing_eagerly() else self._eval_var
具體函式是:
def _eval_var(self, var):
return var.eval(self.session)
def _to_numpy(self, var):
return var.numpy()
所以我們的邏輯拓展如下:
+-------------------------------------------------------------------------------------------------------------------+
| Worker |
| |
| +-------------------------------------------------------------------------------------------------------------+ |
| | run_fn | |
| | +----------------------------------------------------------------------+ | |
| | | while True: | | |
| | | | | |
| | v | | |
| | +-------------------------------------------------+ | | |
| | state.sync() +--------> |broadcast_variables(self.variables, root_rank=0) | | | |
| | + | | | | |
| | | | _save_model | | | |
| | | | | | | |
| | v +-------------------------------------------------+ | | |
| | +------------------+---------------+ | | |
| | | train | | | |
| | | | | | |
| | | optimizer.apply_gradients +---------+ | | |
| | | | | | | |
| | +-------+ state.commit() | | | |
| | | | | | | | |
| | | +----------------------------------+ | | | |
| | | | | | |
| | v v | | |
| | HostsUpdatedInterrupt HorovodInternalError +-------------------------+ | | |
| | + | _load_model | | | |
| | + | | | | | |
| | | | | model.set_weights | | | |
| | | v | optimizer.set_weights | | | |
| | | state.restore() +---> | var.assign(saved_var) | | | |
| | | + | | | | |
| | | | +-------------------------+ | | |
| | +------------------+ <------------------+ | | |
| | | | +-------------------------+ | | |
| | | | | _HorovodBasics | | | |
| | v v | | | | |
| | reset() +-----------------------------> | | | | |
| | +---------------+ | horovod_init | | | |
| | | user callback +<----+ state.on_reset() | | | | |
| | +---------------+ | horovod_init_comm | | | |
| | + | | | | |
| | | +-------------------------+ | | |
| | +-----------------------------------------------------------------------^ | |
| | | |
| +-------------------------------------------------------------------------------------------------------------+ |
| |
+-------------------------------------------------------------------------------------------------------------------+
手機如下:
至此,彈性訓練部分分析結束。下面二~三篇文章將為大家介紹K8S相關。
0xFF 參考
ElasticDL呼叫 Horovod 在Kubernetes上實現彈性 AllReduce(一)
kubernetes 培訓_在Kubernetes上使用horovod進行分散式深度學習培訓
在 Kubernetes 上彈性深度學習訓練利器 -- Elastic Training Operator