Notes on mx.metric.EvalMetric and mx.io.DataIter
First, let's look at the code of MAE_zz, a custom metric class that inherits from mx.metric.EvalMetric:
import mxnet as mx

class MAE_zz(mx.metric.EvalMetric):
    def __init__(self, name=None):
        self.name = "mae"
        super(MAE_zz, self).__init__('mae')  # call the parent constructor, which initializes all inherited members and registers the metric under the name "mae"

    def reset(self):  # reset the per-task counters (totals = batch_size * num_task) and the accumulated correct counts
        """Resets the internal evaluation result to initial state."""
        self.num_inst = [0] * 3      # samples seen per task: [batch_size / num(ctx), batch_size / num(ctx), ...]
        self.sum_metric = [0.0] * 3  # correct predictions per task (one entry per output of the model's final mx.symbol.Group)

    def update(self, labels, preds):  # called once per batch; it only updates the reported value and does not affect training
        """Updates the internal evaluation result."""
        for i in range(len(labels)):  # i indexes the tasks
            pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32')  # predicted labels, shape = batch_size / ctx; preds[i] is task i's prediction with shape (batch_size / ctx, num_classes of task i)
            label = labels[i].asnumpy().astype('int32')  # ground-truth labels; labels holds the true labels of every task
            mx.metric.check_label_shapes(label, pred_label)  # raise ValueError if pred_label and label shapes do not match
            self.sum_metric[i] += (pred_label.flat == label.flat).sum()  # number of correct predictions for task i in this batch (despite the "mae" name, this is classification accuracy)
            self.num_inst[i] += len(pred_label.flat)  # accumulate the number of samples seen for task i (batch_size per update)

    def get(self):  # accuracy per task, accumulated over all contexts and batches so far
        """Gets the current evaluation result."""
        if self.num_inst == 0:  # note: num_inst is a list, so this branch never fires; the per-task nan below handles the empty case
            return self.name, float('nan')
        else:
            return zip(*(('%s-task%d' % (self.name, i), float('nan')
                          if self.num_inst[i] == 0
                          else self.sum_metric[i] / self.num_inst[i])
                         for i in range(3)))  # returns each task's accuracy, zipped into the names "mae-task[i]" and the values sum_metric[i] / num_inst[i]

    def get_name_value(self):
        """Returns zipped name and value pairs."""
        name, value = self.get()  # name = mae-task[i]; value = sum_metric[i] / num_inst[i]
        return list(zip(name, value))
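To see how this metric is exercised, here is a minimal sketch of my own (not from the original post): it builds the metric, feeds it one hand-made batch for three tasks with arbitrary class counts and random scores, and prints the per-task values. The array contents are placeholders, so the printed numbers will vary.

import mxnet as mx

metric = MAE_zz()
metric.reset()

# one batch of size 4 for three tasks; the class counts (2, 3, 4) are arbitrary
preds = [mx.nd.uniform(shape=(4, num_classes)) for num_classes in (2, 3, 4)]
labels = [mx.nd.array([0, 1, 0, 1]),
          mx.nd.array([2, 0, 1, 1]),
          mx.nd.array([3, 3, 0, 2])]

metric.update(labels, preds)
print(metric.get_name_value())
# -> [('mae-task0', ...), ('mae-task1', ...), ('mae-task2', ...)], one accuracy value per task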
Next, let's look at Multi_mnist_iterator, a custom iterator class that inherits from mx.io.DataIter:
class Multi_mnist_iterator(mx.io.DataIter):
    '''multi label mnist iterator'''

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter                   # the wrapped DataIter that actually supplies the data
        self.batch_size = self.data_iter.batch_size  # reuse the wrapped iterator's batch_size

    @property  # expose the wrapped iterator's provide_data through the property getter (see Figure 1)
    def provide_data(self):
        return self.data_iter.provide_data

    @property  # expose provide_label through the property getter (see Figure 1)
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # Different labels should be used here for actual application
        return [('softmax1_label', (self.batch_size,)),
                ('softmax2_label', (self.batch_size,)),
                ('softmax3_label', (self.batch_size,))]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next()  # fetch the next batch from the wrapped iterator (see Figure 2)
        label = batch.label[0]         # label.shape is (batch_size, num_task)
        data = batch.data[0]           # the image tensor, e.g. (batch_size, 1, 28, 28) for MNIST
        label0, label1, label2 = label.T  # transpose the labels to (num_task, batch_size) so the model can compute losses and accuracy per task
        return mx.io.DataBatch(data=[data], label=[label0, label1, label2],
                               pad=batch.pad, index=batch.index)
Figure 1: the wrapped data_iter (its provide_data / provide_label).
Figure 2: a batch instance returned by data_iter.next().
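Since the two figures are not reproduced here, the following sketch of my own shows roughly what they displayed, using random arrays in an mx.io.NDArrayIter as a stand-in for the real MNIST iterator: the provide_data / provide_label of the wrapped iterator and the shapes of the labels in the re-packed DataBatch. The names base_iter and multi_iter are mine.

import numpy as np
import mxnet as mx

# random stand-in for MNIST: 100 images of shape 1x28x28, each with 3 integer labels
data = np.random.rand(100, 1, 28, 28).astype('float32')
label = np.random.randint(0, 10, size=(100, 3)).astype('float32')

base_iter = mx.io.NDArrayIter(data, label, batch_size=10, shuffle=True)
multi_iter = Multi_mnist_iterator(base_iter)

print(multi_iter.provide_data)   # DataDesc of 'data' with shape (10, 1, 28, 28)
print(multi_iter.provide_label)  # [('softmax1_label', (10,)), ('softmax2_label', (10,)), ('softmax3_label', (10,))]

batch = multi_iter.next()
print(len(batch.label), batch.label[0].shape)  # 3 label arrays, each of shape (10,)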
Because provide_data and provide_label above are exposed through @property, it helps to recall what the built-in property does. The signature and docstring excerpt below are copied from the class doc of property.__init__ (as shown in the IDE stub):

def __init__(self, fget=None, fset=None, fdel=None, doc=None):  # known special case of property.__init__
    """
    property(fget=None, fset=None, fdel=None, doc=None) -> property attribute

    class C(object):
        @property
        def x(self):
            "I am the 'x' property."
            return self._x

        @x.setter
        def x(self, value):
            self._x = value

        @x.deleter
        def x(self):
            del self._x

    # (copied from class doc)
    """
    pass
- From the code above, turning the method x (which returns the value of x) into a read-only attribute only requires adding the @property decorator, i.e. C.x() => C.x;
- @property in turn creates another decorator, @x.setter, which turns an x method into the attribute's setter, i.e. C.x(value) => C.x = value;
- it likewise creates the decorator @x.deleter, which makes an x method delete the attribute, i.e. del C.x;
- these three roles correspond to property's three methods: getter, setter and deleter. A minimal runnable example follows below.
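As a small illustration of those three roles (my own toy example, not from the post), the class below wraps a private _value attribute with a getter, a setter and a deleter:

class Celsius(object):
    def __init__(self, value=0.0):
        self._value = value

    @property               # getter: c.value
    def value(self):
        return self._value

    @value.setter           # setter: c.value = 25
    def value(self, new_value):
        self._value = float(new_value)

    @value.deleter          # deleter: del c.value
    def value(self):
        del self._value

c = Celsius()
c.value = 25      # calls the setter
print(c.value)    # calls the getter -> 25.0
del c.value       # calls the deleter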