MXNet中bucket機制註記
阿新 • • 發佈:2017-12-19
.org sse shape 沒有 sta ams origin done org
Preface
之前看API以為bucket
是一個根植於底層操作的接口(MXNet doc
功不可沒 -_-|| )。從LSTM
看過來,接觸到了一些相關的程序,後面再把bucketing_module.py
那部分查看了下,發現bucket只是一個應用層機制,主要的實現存在於module/bucketing_module.py裏面。原理清晰,實現簡潔,在這做個記號。
Code & Comments
先放些相關的鏈接,做個預備。
- MXNet 官方的文檔(\tucao 出個文檔真不容易,還帶時效性...)
- 大神的blog闡述,鞭辟入裏
- 之前關於LSTM的blog
鑒於大神已經在這篇[blog]裏面說得生動透徹了,這裏就能省就省,然後說些大神沒功夫顧及的細節。
另外考慮到MXNet的鏈接經常表現出不靠譜的癥狀(\kuxia),歸結一下1
要使用bucket機制,初始化Module時傳入的symbol應該是一個函數,這個函數在被調用時將被傳入叠代器中的bucket_key參數
。
從調用路徑的順序來走一遍把。
在fit
裏面經過bind
,init
等操作,後面會調用prepare
對預取出的數據(如果有)進行準備:
# module/bucketing_module.py
def prepare(self, data_batch):
"""Prepares a data batch for forward.
Parameters
----------
data_batch : DataBatch
"""
# perform bind if haven‘t done so
assert self.binded and self.params_initialized
bucket_key = data_batch.bucket_key
original_bucket_key = self._curr_bucket_key
data_shapes = data_batch.provide_data
label_shapes = data_batch.provide_label
self .switch_bucket(bucket_key, data_shapes, label_shapes)
# switch back
self.switch_bucket(original_bucket_key, None, None)
顯然,switch_bucket
就是負責進行重新綁定的:
# module/bucketing_module.py
def switch_bucket(self, bucket_key, data_shapes, label_shapes=None):
assert self.binded, ‘call bind before switching bucket‘
if not bucket_key in self._buckets: # check if there is already...
symbol, data_names, label_names = self._sym_gen(bucket_key)
module = Module(symbol, data_names, label_names,
logger=self.logger, context=self._context,
work_load_list=self._work_load_list,
fixed_param_names=self._fixed_param_names,
state_names=self._state_names)
module.bind(data_shapes, label_shapes, self._curr_module.for_training,
self._curr_module.inputs_need_grad,
force_rebind=False, shared_module=self._buckets[self._default_bucket_key])
self._buckets[bucket_key] = module
self._curr_module = self._buckets[bucket_key]
self._curr_bucket_key = bucket_key
邏輯很明白,_curr_module
裏面放了眾多的module,這些module的參數全都指向同一組。如果出入的bucket_key
沒有出現過,就bind一個並放入*_curr_module列表裏面去;如果已經有了(包括剛剛bind出來的),就切換到那個module*上。
Misc
其他有一些相關的材料順帶放在這。
- 上一篇blog裏面推測bucket機制可能會對補齊的那部分進行處理,這一點與
io.py
裏面的DataBatch
中pad
變量有些聯系。在module/base_module.py中,查找pad的引用,發現和io.py裏面的註釋一致,只在prediction的時候進行了使用,訓練的時候被忽視。 exmple/rnn/bucketing
裏面有更高層接口的使用示例。
MXNet中bucket機制註記