[原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知
[原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知
目錄- [原始碼解析] 深度學習分散式訓練框架 horovod (15) --- 廣播 & 通知
0x00 摘要
Horovod 是Uber於2017年釋出的一個易於使用的高效能的分散式訓練框架,在業界得到了廣泛應用。
本系列將通過原始碼分析來帶領大家瞭解 Horovod。本文是系列第十五篇,看看horovod 彈性訓練如何廣播和傳送通知。
本系列其他文章連結如下:
[原始碼解析] 深度學習分散式訓練框架 Horovod (1) --- 基礎知識
[原始碼解析] 深度學習分散式訓練框架 horovod (2) --- 從使用者角度切入
[原始碼解析] 深度學習分散式訓練框架 horovod (3) --- Horovodrun背後做了什麼
[原始碼解析] 深度學習分散式訓練框架 horovod (4) --- 網路基礎 & Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (5) --- 融合框架
[原始碼解析] 深度學習分散式訓練框架 horovod (6) --- 後臺執行緒架構
[原始碼解析] 深度學習分散式訓練框架 horovod (7) --- DistributedOptimizer
[原始碼解析] 深度學習分散式訓練框架 horovod (8) --- on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (9) --- 啟動 on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (10) --- run on spark
[原始碼解析] 深度學習分散式訓練框架 horovod (11) --- on spark --- GLOO 方案
[原始碼解析] 深度學習分散式訓練框架 horovod (12) --- 彈性訓練總體架構
[原始碼解析] 深度學習分散式訓練框架 horovod (13) --- 彈性訓練之 Driver
[原始碼解析] 深度學習分散式訓練框架 horovod (14) --- 彈性訓練發現節點 & State
0x01 問題
首先,我們提出一個問題:為什麼彈性訓練 需要有廣播?
答案就是:因為捕獲兩種異常之後,需要廣播到各個worker。
1.1 HorovodInternalError
關於 HorovodInternalError 異常處理,我們看看具體容錯機制,就可以知道緣由:
- 在
hvd.elastic.run
裝飾器捕獲異常; - 如果是
HorovodInternalError
,就恢復到最近一次提交的狀態,此時因為是allreduce等異常,所以所有worker都處於停止狀態; - driver 會根據當前正在執行的節點重新執行一個 rendezvous,以便重新初始化 Horovod context;
- 當新的通訊域構造成功後,rank = 0 的 worker 會將自身的模型廣播給其他 worker;
- 所有worker接著上次停止的迭代步數繼續訓練;
因為需要從 rank 0 廣播變數給其他程序,所以必須有一個廣播機制。
1.2 HostsUpdateInterrupt
關於 HostsUpdateInterrupt 異常處理,我們看看具體原因。
- 當驅動程序通過節點發現指令碼發現一個節點被標記為新增或者移除時,它將傳送一個通知到所有workers,在下一次 state.commit() 或者更輕量的 state.check_host_updates() 被呼叫時,會丟擲一個 HostsUpdateInterrupt 異常。這個異常類似於 HorovodInternalError 異常,但是引數狀態等不會從最近一次commit中恢復,而是從當前實時的引數中恢復。
- check_host_updates 方法 會從
_host_messages
中讀取訊息,積累更新,如其方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲異常。 - 具體同步使用
_bcast_object
(然後內部呼叫到了 MPI)。
需要一個廣播機制在每個 worker 之間同步狀態(因為這些worker目前都是在正常訓練,需要有一個東西統一打斷他們的訓練,從而重新組建一個通訊環),目的是讓這些 worker 同時丟擲 HostsUpdateInterrupt 異常。
其次,我們需要回顧下上文的流程圖,本文將對其部分內部流程進行細化。
0x02 廣播機制
我們具體剖析廣播機制如下,因為廣播是和具體框架密切結合,所以我們以tensorflow為例,具體程式碼在horovod/tensorflow/elastic.py 之中。
2.1 廣播實現
在 horovod/tensorflow/elastic.py 之中,就是針對 TF 做的特定實現。其中會依據 TF 的版本做不同處理。
2.1.1 TensorFlowKerasState
以 TensorFlowKerasState 為例,在初始化的時候,因為有廣播物件的需要,比如在 TensorFlowKerasState 之中配置了_bcast_model
用來廣播模型,bcast_object
用來廣播物件,broadcast_variables
用來廣播變數。
而且提供了sync函式負責廣播,可以看出來呼叫了_bcast_model
。
class TensorFlowKerasState(ObjectState):
def __init__(self, model, optimizer=None, backend=None, **kwargs):
if not backend or _executing_eagerly():
# 這裡設定了廣播函式
self._bcast_model = lambda: _broadcast_model(self.model, self.optimizer, backend=self.backend)
bcast_object = broadcast_object
else:
# For TensorFlow v1, we need to reuse the broadcast op to prevent incrementing the uids
# 這裡設定了廣播函式
bcast_op = broadcast_variables(_global_variables(), root_rank=0)
self._bcast_model = lambda: self.backend.get_session().run(bcast_op)
bcast_object = broadcast_object_fn(session=self.backend.get_session())
def sync(self):
self._bcast_model() #廣播模型
self._save_model()
super(TensorFlowKerasState, self).sync()
2.1.2 廣播模型
_broadcast_model 函式會 廣播 模型變數,optimizer變數。
def _broadcast_model(model, optimizer, backend):
if _executing_eagerly():
# TensorFlow 2.0 or TensorFlow eager
broadcast_variables(model.variables, root_rank=0) # 廣播模型變數
broadcast_variables(optimizer.variables(), root_rank=0) # 廣播優化器變數
else:
bcast_op = broadcast_variables(_global_variables(), root_rank=0)
backend.get_session().run(bcast_op)
2.1.3 廣播變數
廣播變數的具體實現 在 horovod/tensorflow/functions.py 之中。broadcast_variables 的作用是從 root rank(即 rank 0)廣播變數到其他的程序。
具體也根據 TF 版本做了區別。
def _make_subgraph(f):
return tf.function(f)
@_cache
def _make_broadcast_group_fn():
if _executing_eagerly():
# Eager mode will parallelize independent control flow
def broadcast_group(variables, root_rank): # 在這裡定義
for var in variables:
var.assign(broadcast(var, root_rank)) # 呼叫MPI函式,這裡都指定了是root_rank
return _make_subgraph(broadcast_group)
else:
# Graph mode requires an Op
def broadcast_group(variables, root_rank): # 在這裡定義
# tf.group()用於創造一個操作,可以將傳入引數的所有操作組合,當這個操作完成後,所有 input 中的所有 ops 都已完成。tf.group()操作沒有輸出。
return tf.group(*[var.assign(broadcast(var, root_rank)) # 這裡呼叫MPI函式
for var in variables])
return broadcast_group
def broadcast_variables(variables, root_rank):
"""Broadcasts variables from root rank to all other processes.
"""
broadcast_group = _make_broadcast_group_fn()
return broadcast_group(variables, root_rank # 在上面定義
2.1.4 廣播物件
廣播物件 的作用是從 root rank(即 rank 0)廣播物件到其他的程序。 廣播物件和廣播變數的區別是:物件需要序列化和反序列化。
def broadcast_object(obj, root_rank=0, session=None, name=None):
"""
Serializes and broadcasts an object from root rank to all other processes.
Arguments:
obj: An object capable of being serialized without losing any context.
root_rank: The rank of the process from which parameters will be
broadcasted to all other processes.
session: Session for TensorFlow v1 compatibility.
name: Optional name to use during broadcast, will default to the class
type.
Returns:
The object that was broadcast from the `root_rank`.
"""
if name is None:
name = type(obj).__name__
def to_numpy(v): # 依據tf版本不同做不同處理
if not _executing_eagerly():
sess = session or ops.get_default_session()
return sess.run(v)
else:
return v.numpy()
if rank() == root_rank:
b = io.BytesIO() # BytesIO實現了在記憶體中讀寫bytes
cloudpickle.dump(obj, b) # 序列化,編碼成一個二進位制檔案
t = tf.convert_to_tensor(bytearray(b.getvalue()), dtype=tf.uint8)
sz = tf.convert_to_tensor([t.shape[0]], dtype=tf.int32) # 張量對應維度的數值
to_numpy(broadcast(sz, root_rank, name + '.sz')) # 廣播維度
else:
sz = tf.convert_to_tensor([0], dtype=tf.int32)
sz = to_numpy(broadcast(sz, root_rank, name + '.sz')) # 接受維度
t = tf.zeros(sz.tolist()[0], dtype=tf.uint8)
t = to_numpy(broadcast(t, root_rank, name + '.t')) # 廣播物件內容
if rank() != root_rank:
buf = io.BytesIO(t.tobytes())
obj = cloudpickle.load(buf) # 反序列化,解碼成原本的物件
return obj
2.1.5 HVD C++
底層會呼叫到 MPI 函式完成廣播功能。
def broadcast(tensor, root_rank, name=None, ignore_name_scope=False):
"""An op which broadcasts the input tensor on root rank to the same input tensor
on all other Horovod processes.
The broadcast operation is keyed by the name of the op. The tensor type and
shape must be the same on all Horovod processes for a given name. The broadcast
will not start until all processes are ready to send and receive the tensor.
Returns:
A tensor of the same shape and type as `tensor`, with the value broadcasted
from root rank.
"""
if name is None and not _executing_eagerly():
name = 'HorovodBroadcast_%s' % _normalize_name(tensor.name)
return MPI_LIB.horovod_broadcast(tensor, name=name, root_rank=root_rank,
ignore_name_scope=ignore_name_scope)
2.1.6 MPI
MPI_BCAST的作用是:從一個序列號為root的程序將一條訊息廣播發送到組內的所有程序, 包括它本身在內。
因為之前指定了root_rank,所以即使所有worker雖然都呼叫了同樣程式碼,也只是會把 root_rank 通訊訊息緩衝區中的訊息拷貝到其他所有程序中去。
void MPIController::Bcast(void* buffer, size_t size, int root_rank,
Communicator communicator) {
MPI_Comm comm = mpi_ctx_.GetMPICommunicator(communicator);
int ret_code = MPI_Bcast(buffer, size, MPI_BYTE, root_rank, comm);
if (ret_code != MPI_SUCCESS) {
throw std::runtime_error(
"MPI_Broadcast failed, see MPI output for details.");
}
}
2.1.7 小結
我們總結一下各個函式:
_bcast_model
用來廣播模型;bcast_object
用來廣播物件;broadcast_variables
用來廣播變數;- 廣播物件和廣播變數的區別是:物件需要序列化和反序列化。
_broadcast_model
就是呼叫了broadcast_variables
完成對模型引數的廣播;broadcast_variables
中呼叫了broadcast_group
,broadcast_group
主要就是利用tf.group()把廣播操作組合起來;
2.2 使用
2.2.1 HorovodInternalError
當捕獲HorovodInternalError時候,會進行廣播同步,目的是當新的通訊域構造成功後,rank = 0 的 worker 會將自身的模型廣播給其他 worker。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
notification_manager.init()
notification_manager.register_listener(state)
skip_sync = False
try:
while True:
if not skip_sync:
state.sync() # 這裡會進行廣播同步,就是TensorFlowKerasState.sync
try:
return func(state, *args, **kwargs)
except HorovodInternalError:
state.restore() # 捕獲一場,然後繼續while迴圈
skip_sync = False
except HostsUpdatedInterrupt as e:
skip_sync = e.skip_sync
reset()
state.on_reset()
finally:
notification_manager.remove_listener(state)
return wrapper
具體如下:
Worker rank 0 Worker rank n
+ +
| |
| |
| |
v |
Catch HorovodInternalError |
+ |
| |
| |
| |
sync |
| |
| |
v |
_broadcast_model(model) |
+ |
| |
| |
| |
v |
broadcast_variables(model.variables) |
|
broadcast_variables(optimizer.variables) |
|
+ |
| |
| |
| |
v |
broadcast_group |
+ |
| |
| |
| |
v |
MPI_LIB.horovod_broadcast +-------------------> |
+ |
| |
| |
v v
2.2.2 HostsUpdateInterrupt
廣播物件作用是 在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲 HostsUpdateInterrupt 異常。
具體如何使用?
在 WorkerNotificationService . _handle
方法之中,呼叫了 self._manager.handle_hosts_updated(req.timestamp, req.res)
進行通知更新。
WorkerNotificationManager. handle_hosts_updated 方法之中,會呼叫註冊的 state,逐一通知更新。
def handle_hosts_updated(self, timestamp, update_res):
for listener in self._listeners:
listener.on_hosts_updated(timestamp, update_res)
是在 State 的幾個方法中可以看到。
- on_hosts_updated :當有 host 變化時候呼叫,即 向
_host_messages
這個 queue 放入一個訊息; - commit :使用者會定期呼叫此函式,會儲存狀態,檢查 host 更改;
- check_host_updates : 會從
_host_messages
中讀取訊息,積累更新,如方法中註釋所述,會在每個 worker 之間同步狀態,目的是讓這些 worker 同時丟擲異常。具體同步使用_bcast_object
;
check_host_updates程式碼如下:
def check_host_updates(self):
"""Checks that a notification has been sent indicating that hosts can be added or will be removed.
Raises a `HostsUpdatedInterrupt` if such a notification has been received.
"""
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
# In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
# 這裡會廣播
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
具體如下:
+---------------------------+ +--------------+ +-------------+
|Catch HostsUpdatedInterrupt| | Worker rank 1| |Worker rank n|
+---------+-----------------+ +-------+------+ +----+--------+
| | |
| | |
| | |
v | |
| |
WorkerNotificationService | |
+ | |
| | |
| | |
| | |
v | |
| |
manager.handle_hosts_updated+------------> | |
| |
| |
v |
|
on_hosts_updated |
+ |
| |
| |
| |
check_host_updates |
| |
| |
| |
| |
v |
|
broadcast_object |
+ |
| |
| |
| |
| |
v |
|
MPI_LIB.horovod_broadcast +----> |
+ |
| |
| |
v v
0x03 通知機制
上圖中用到 manager.handle_hosts_updated,manager就是 WorkerNotificationManager。
所以我們順著討論下 WorkerNotificationManager,這是 Hovorod 的通知機制。
3.1 WorkerNotificationManager 生成
每個host 只有一個 WorkerNotificationManager,也只有一個 WorkerNotificationService。
注意:是 ElasticDriver 會作為client,給這些 WorkerNotificationService 發訊息,從而引起 WorkerNotificationManager 的對應操作。
horovod/common/elastic.py 有如下程式碼完成了例項生成。
notification_manager = WorkerNotificationManager()
WorkerNotificationManager 定義如下:
class WorkerNotificationManager(object):
def __init__(self):
self._lock = threading.Lock()
self._service = WorkerNotificationService(secret_key, nic, self)
self._listeners = set()
3.2 初始化
在使用者程式碼啟動之前,會先初始化 WorkerNotificationManager。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
# 初始化 WorkerNotificationManager
notification_manager.init()
# 把自己對應的 state 註冊到 notification_manager
notification_manager.register_listener(state)
WorkerNotificationManager初始化程式碼如下,其邏輯是:
- 如果 _service 已經生成,則直接返回,這就保證了每個host之中只有一個 WorkerNotificationService。
- 從系統變數中得到 rendezvous 的各種資訊,比如地址,埠,key 等等;
- 生成 WorkerNotificationService,賦值給 _service;
- 使用 put_data_into_kvstore 把 本 worker 的地址 和 給其在邏輯通訊環分配的序號 rank 傳送給 rendezvous(這個為了後續生成 WorkerNotificationClient 使用)。
- 備註:這個 rendezvous 會儲存每個 worker 的地址和給其在邏輯通訊環分配的序號 rank。worker 程序可以通過這個 rendezvous 來構造新的通訊域。
def init(self, rendezvous_addr=None, rendezvous_port=None,
nic=None, hostname=None, local_rank=None):
with self._lock:
if self._service:
return
# 從系統變數中得到 rendezvous 的各種資訊,比如地址,埠,key 等等
rendezvous_addr = rendezvous_addr or os.environ.get(HOROVOD_GLOO_RENDEZVOUS_ADDR)
rendezvous_port = rendezvous_port if rendezvous_port is not None else \
int(os.environ.get(HOROVOD_GLOO_RENDEZVOUS_PORT))
nic = nic or os.environ.get(HOROVOD_GLOO_IFACE)
hostname = hostname or os.environ.get(HOROVOD_HOSTNAME)
local_rank = local_rank if local_rank is not None else \
int(os.environ.get(HOROVOD_LOCAL_RANK))
secret_key = secret.make_secret_key()
self._service = WorkerNotificationService(secret_key, nic, self)
value = (self._service.addresses(), secret_key)
# 把本worker的地址 和 給其在邏輯通訊環分配的序號 rank 傳送給 rendezvous
put_data_into_kvstore(rendezvous_addr,
rendezvous_port,
PUT_WORKER_ADDRESSES,
self._create_id(hostname, local_rank),
value)
具體 put_data_into_kvstore 如下。
def put_data_into_kvstore(addr, port, scope, key, value):
try:
url = "http://{addr}:{port}/{scope}/{key}".format(
addr=addr, port=str(port), scope=scope, key=key
)
req = Request(url, data=codec.dumps_base64(value, to_ascii=False))
req.get_method = lambda: "PUT" # for urllib2 compatibility
urlopen(req)
except (HTTPError, URLError) as e:
raise RuntimeError("Put data input KVStore server failed.", e)
3.3 註冊State
使用者程式碼啟動之前,還會把自己對應的 state 註冊到 notification_manager。
def run_fn(func, reset):
@functools.wraps(func)
def wrapper(state, *args, **kwargs):
# 初始化 WorkerNotificationManager
notification_manager.init()
# 把自己對應的 state 註冊到 notification_manager
notification_manager.register_listener(state)
具體程式碼如下:
def register_listener(self, listener):
self._listeners.add(listener)
def remove_listener(self, listener):
self._listeners.remove(listener)
3.4 WorkerNotificationService
WorkerNotificationService 在每個host之中也只有一個,用來接受其 client 發來的 HostsUpdatedRequest 訊息,進行處理。可以看到,其繼承了 network.BasicService,這意味著 WorkerNotificationService 本身是一個http server,可以和其client互動,大家可以想想之前介紹的各種 driver / client,就可以理解其機制了。
class WorkerNotificationService(network.BasicService):
NAME = 'worker notification service'
def __init__(self, key, nic, manager):
super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
key,
nic)
self._manager = manager
def _handle(self, req, client_address):
if isinstance(req, HostsUpdatedRequest):
self._manager.handle_hosts_updated(req.timestamp, req.res)
return network.AckResponse()
return super(WorkerNotificationService, self)._handle(req, client_address)
邏輯如下:
+-------------------------------+ +---------------------------+
| WorkerNotificationManager | | rendezvous |
| +------------------------> | |
| | put_data_into_kvstore | |
| | | |
| | +---------------------------+
| _listeners |
| + | +---------------------------+
| | _service +-----------------------------> | WorkerNotificationService |
| | | | |
+-----------------------+-------+ | |
| ^ | |
| | | |
| | | |
| +----------------------------------------+ _manager |
| | |
v | |
+---------------------------+
[State 1, State 2, ......, State n]
3.5 WorkerNotificationClient
WorkerNotificationClient 就是用來給 WorkerNotificationService 傳送訊息的介面。
ElasticDriver 中,會針對每個 worker 生成一個對應的 WorkerNotificationClient,用來進行通知。
class WorkerNotificationClient(network.BasicClient):
def __init__(self, addresses, key, verbose, match_intf=False):
super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
addresses,
key,
verbose,
match_intf=match_intf)
def notify_hosts_updated(self, timestamp, update_res):
self._send(HostsUpdatedRequest(timestamp, update_res))
3.6 生成 Client
3.6.1 註冊時機
回顧一下,在 WorkerNotificationManager 的初始化函式 init 中,會給 rendezvous 傳送put 請求,進行註冊。
註冊資訊就是為了 生成client。
put_data_into_kvstore(rendezvous_addr,
rendezvous_port,
PUT_WORKER_ADDRESSES,
self._create_id(hostname, local_rank),
value)
3.6.2 註冊 worker
在 ElasticRendezvousHandler 中有 _put_value,用來處理 PUT_WORKER_ADDRESSES。呼叫 driver 處理。
# 注意,這裡在 Rendezvous Server 之內
def _put_value(self, scope, key, value):
if scope == PUT_WORKER_ADDRESSES:
host, local_rank = key.split(':')
addresses, secret_key = codec.loads_base64(value)
self._put_worker_addresses(host, int(local_rank), addresses, secret_key)
super(RendezvousHandler, self)._put_value(scope, key, value)
def _put_worker_addresses(self, host, local_rank, addresses, secret_key):
# 這裡呼叫driver進行處理
driver.register_worker_server(host, local_rank, addresses, secret_key)
3.6.3 生成 WorkerNotificationClient
ElasticDriver 中,會針對每個 worker 生成一個對應的 WorkerNotificationClient,用來進行通知。
這裡需要注意:ElasticDriver 就是 WorkerNotificationClient 的使用者,需要通知各個worker時候,就呼叫這些WorkerNotificationClient ,給對應host上的 WorkerNotificationService發訊息,從而引起WorkerNotificationManager做相應處理。
# 這裡是 ElasticDriver 之中
def register_worker_server(self, host, slot, addresses, secret_key):
self._worker_clients[(host, slot)] = WorkerNotificationClient(
addresses, secret_key, self._verbose)
邏輯如下:
+-------------------------------+
| WorkerNotificationManager | +---------------------------+ +----------------------------+
| | | rendezvous | | ElasticRendezvousHandler |
| init +--------------------------------> | +-------> | |
| | 1 put_data_into_kvstore | | | |
| | | | | |
| | +---------------------------+ +------------------+---------+
| _listeners | |
| + | +---------------------------+ |
| | _service +-----------------------------> | WorkerNotificationService | |
| | | | | |
+-----------------------+-------+ | | |
| ^ | | |
| | | | |
| | | | |
| +----------------------------------------+ _manager | |
| | | |
v | | |
+---------------------------+ |
[State 1, State 2, ......, State n] |
|
+-------------------------------------------------------------------------------------------+
| 2 register_worker_server
|
|
v
3 new instance
+-------------------------------+
|ElasticDriver | +----------------------------+ +---------------------------+
| | | WorkerNotificationClient 1 | |WorkerNotificationClient n |
| | | | | |
| | | | | |
| _worker_clients +---------------> | (host 1, slot 1) | ... | (host n, slot n) |
| | | For worker 1 | | For worker n |
| | | | | |
+-------------------------------+ +----------------------------+ +---------------------------+
手機如圖:
3.7 使用
3.7.1 發現更新
ElasticDriver._discovery_thread
之中 如果發現有 host 變化,則呼叫 self._notify_workers_host_changes
來通知。
def _notify_workers_host_changes(self, current_hosts, update_res):
next_host_assignments = {}
if current_hosts.count_available_slots() >= self._min_np:
# Assignments are required to be stable via contract
next_host_assignments, _ = self._get_host_assignments(current_hosts)
if next_host_assignments == self.host_assignments:
# Skip notifying workers when host changes would not result in changes of host assignments
return
coordinator_slot_info = self.get_coordinator_info()
coordinator_client = self.get_worker_client(coordinator_slot_info)
timestamp = _epoch_time_s()
coordinator_client.notify_hosts_updated(timestamp, update_res)
3.7.2 獲取 client
get_worker_client 函式就是獲取 WorkerNotificationClient。就是依據 host,slot 資訊來找到某一個 worker 對應的 client。
def get_worker_client(self, slot_info):
return self._worker_clients.get((slot_info.hostname, slot_info.local_rank))
3.7.3 傳送HostsUpdatedRequest
notify_hosts_updated 的作用是傳送HostsUpdatedRequest
class WorkerNotificationClient(network.BasicClient):
def __init__(self, addresses, key, verbose, match_intf=False):
super(WorkerNotificationClient, self).__init__(WorkerNotificationService.NAME,
addresses,
key,
verbose,
match_intf=match_intf)
def notify_hosts_updated(self, timestamp, update_res):
self._send(HostsUpdatedRequest(timestamp, update_res))
3.7.4 處理 HostsUpdatedRequest
WorkerNotificationService 之中會處理HostsUpdatedRequest,呼叫 WorkerNotificationManager處理。
class WorkerNotificationService(network.BasicService):
NAME = 'worker notification service'
def __init__(self, key, nic, manager):
super(WorkerNotificationService, self).__init__(WorkerNotificationService.NAME,
key,
nic)
self._manager = manager
def _handle(self, req, client_address):
if isinstance(req, HostsUpdatedRequest):
self._manager.handle_hosts_updated(req.timestamp, req.res)
return network.AckResponse()
return super(WorkerNotificationService, self)._handle(req, client_address)
3.7.5 WorkerNotificationManager
所以,當有host 更新時候,WorkerNotificationManager 中的 handle_hosts_updated 如下,最終呼叫到 state 的 on_hosts_updated。
def handle_hosts_updated(self, timestamp, update_res):
for listener in self._listeners: # 遍歷state
listener.on_hosts_updated(timestamp, update_res)
State 的實現如下:
def on_hosts_updated(self, timestamp, update_res):
self._host_messages.put((timestamp, update_res))
邏輯如下圖:
+-----------------------------v
^ thread loop |
| |
+----------------+----------------------+ |
| ElasticDriver._discovery_thread | |
1 _notify_workers_host_changes | | |
| | |
+------------------+ | |
| | | |
| | HostManager.update_available_hosts | |
| | | |
| +-----------------+---------------------+ |
| ^ |
| | |
| | |
| +----------<---------------+ v
v
+---------------------------+ 2 HostsUpdatedRequest +----------------------------+ handle_hosts_updated +----------------------------+
| | | | | |
| WorkerNotificationClient +----------------------> | WorkerNotificationService | +------------------> | WorkerNotificationManager |
| | | | | |
+---------------------------+ +----------------------------+ +------+---------------------+
|
|
| on_hosts_updated
|
v
+-----------------------+
| State | |
| | put |
| v |
| _host_messages |
+-----------------------+
手機如下:
3.7.6 處理更新
在使用者呼叫 commit 的時候,才會呼叫 check_host_updates 檢查更新。
def commit(self):
self.save()
self.check_host_updates()
檢查更新就是看看 _host_messages 有沒有新的訊息,如果發現 host 有變化,就會產生一個 HostsUpdatedInterrupt 異常。
def check_host_updates(self):
# Iterate through the update messages sent from the server. If the update timestamp
# is greater than the last update timestamp, then trigger a HostsUpdatedException.
last_updated_timestamp = prev_timestamp = self._last_updated_timestamp
all_update = HostUpdateResult.no_update
while not self._host_messages.empty():
timestamp, update = self._host_messages.get()
if timestamp > last_updated_timestamp:
last_updated_timestamp = timestamp
all_update |= update
# In order to ensure all workers raise the exception at the same time, we need to sync
# the updated state across all the workers.
# TODO(travis): this should be a max allreduce to account for changes in rank 0
prev_timestamp, self._last_updated_timestamp, all_update = \
self._bcast_object((prev_timestamp, last_updated_timestamp, all_update))
# At this point, updated state is globally consistent across all ranks.
if self._last_updated_timestamp > prev_timestamp:
raise HostsUpdatedInterrupt(all_update == HostUpdateResult.removed)
在worker程序出現 HorvodInternalError 錯誤或者 HostsUpdatedInterrupt 節點增刪時,會捕獲這兩個錯誤,呼叫 reset 來進行容錯處理。於是就把流程前後串聯了起來。
具體如下:
至此,廣播通知機制我們整理完畢,下一篇介紹 worker 如何運作。
0xEE 個人資訊
★★★★★★關於生活和技術的思考★★★★★★
微信公眾賬號:羅西的思考
如果您想及時得到個人撰寫文章的訊息推送,或者想看看個人推薦的技術資料,敬請關注。