1. 程式人生 > >mx.metric.EvalMetric和mx.io.DataIter學習筆記

mx.metric.EvalMetric和mx.io.DataIter學習筆記

首先看下metric繼承類class MAE_zz(mx.metric.EvalMetric)程式碼:

class MAE_zz(mx.metric.EvalMetric):

    def __init__(self, name = None):
        self.name = "mae"
        super(MAE_zz, self).__init__('mae') # 呼叫其父類,初始化父類的所有資料成員,並將父類名字命名為“mae”

    def reset(self): # 初始化總數(batch_size*num_task)和準確率
        """Resets the internal evaluation result to initial state."""
self.num_inst = [0]*3 # [batch_size/num(ctx), b../n.., ...] self.sum_metric = [0.0]*3 # 每個任務(網路模型最後的輸出層mx.symbol.Group)的準確率 def update(self, labels, preds): # 每一個batch更新metric,便於mae的輸出,不影響網路的訓練 """Updates the internal evaluation result.""" for i in range(len(labels)
): # i代表每個任務 pred_label = mx.nd.argmax_channel(preds[i]).asnumpy().astype('int32') # 得到pred_label(label的預測值,shape=batch_size/ctx);preds[i]代表第i個任務的預測值(shape=batch_size/ctx乘以第i個任務的分類個數)。 label = labels[i].asnumpy().astype('int32') # 得到label(真實label);labels為多個任務的真實label。 mx.
metric.check_label_shapes(label, pred_label) # 檢查pred_label和label的shape是否一致,不一致則raise ValueError self.sum_metric[i] += (pred_label.flat == label.flat).sum() # 計算一個batch裡,每i個任務下的預測正確數 self.num_inst[i] += len(pred_label.flat) # batch_size * len(i) def get(self): # 得到每個batch下所有ctx的準確率 """Gets the current evaluation result. """ if self.num_inst == 0: return self.name, float('nan') else: return zip(*(('%s-task%d' % (self.name, i), float('nan') if self.num_inst[i] == 0 else self.sum_metric[i] / self.num_inst[i]) for i in range(3))) # 返回每個任務的準確率,並zip為mae-task[i] 和 self.sum_metric[i] / self.num_inst[i] def get_name_value(self): """Returns zipped name and value pairs. """ name, value = self.get() # name=mae-task[i]; value=self.sum_metric[i] / self.num_inst[i] return list(zip(name, value))

其次看下迭代器繼承類class Multi_mnist_iterator(mx.io.DataIter)程式碼:

class Multi_mnist_iterator(mx.io.DataIter):
    '''multi label mnist iterator'''

    def __init__(self, data_iter):
        super(Multi_mnist_iterator, self).__init__()
        self.data_iter = data_iter # 初始化data_iter(DataIter迭代器的資料)
        self.batch_size = self.data_iter.batch_size # 得到DataIter的batch_size

    @property # 呼叫property裝飾器的getter獲取data_iter的provide_data(**見圖1**)
    def provide_data(self):
        return self.data_iter.provide_data

    @property # 呼叫property裝飾器的getter獲取data_iter的provide_label(**見圖1**)
    def provide_label(self):
        provide_label = self.data_iter.provide_label[0]
        # Different labels should be used here for actual application
        return [('softmax1_label',(self.batch_size,)),
                ('softmax2_label',(self.batch_size,)),
                ('softmax3_label',(self.batch_size,))]

    def hard_reset(self):
        self.data_iter.hard_reset()

    def reset(self):
        self.data_iter.reset()

    def next(self):
        batch = self.data_iter.next() # 得到data_iter迭代器例項(**見圖2**)
        label = batch.label[0] # label.shape 為batch * num_task
        data = batch.data[0] # data.shape為batch*num_task*weigh*height
        label0, label1, label2 = label.T # 將label變為num_task * batch 輸入給模型;以便模型分任務計算準確率。
        return mx.io.DataBatch(data=[data], label=[label0, label1, label2],
                               pad=batch.pad, index=batch.index)

圖一:data_iter 迭代器
在這裡插入圖片描述

圖二:batch迭代器例項
在這裡插入圖片描述

@property 的作用

    def __init__(self, fget=None, fset=None, fdel=None, doc=None): # known special case of property.__init__
        """
        property(fget=None, fset=None, fdel=None, doc=None) -> property attribute       
        class C(object):
            @property
            def x(self):
                "I am the 'x' property."
                return self._x
            @x.setter
            def x(self, value):
                self._x = value
            @x.deleter
            def x(self):
                del self._x
   
        # (copied from class doc)
        """
        pass
  • 看上面程式碼可知,把x方法(返回x的值)變為屬性只需要加上@property裝飾器即可,即:C.x() => C.x;
  • 此時@property本身又會建立另外一個裝飾器@x.setter,負責把x方法變成給屬性賦值,即:C.x(value) => C.x=value
  • 此時@property本身又會建立另外一個裝飾器@x.deleter,負責刪除C中的元素,即:del C.x。
  • 此3種屬性對應於property的3種方法getter,setter 和 deleter 。