1. 程式人生 > >mxnet實現自己的影象資料集增強方法

mxnet實現自己的影象資料集增強方法

        深度學習做影象相關的內容時候,資料集增強是常用並且十分有效的手段,可以有效的對口過擬合以及提高模型的準確率,針對不同的問題有時候需要特定的方式對資料進行變換。Mxnet已經內建了一些常用的增強手段,例如randomcrop,mirror,顏色抖動等。

        但是,當需要具體的針對性的資料增強的時候,就需要自己寫一個augmenter,對此 mxnet還是比較簡單的和靈活的。

最方便的方式當然是在python層面繼承Augmenter:

        直接上程式碼了,過程還是很清晰的,繼承augmenter, 並實現方法即可

        程式碼中包括rotate 和noise增強, 基於opencv, 由於noise 增強大部分工作是python層面做的,速度比較慢,有待改進。

        呼叫方式是和呼叫mxnet的原有方法是一樣的,比如程式碼檔案是MyAugmentation.py。

        CAUTION: 這裡提醒大家一下mxnet裡面用影象增強程式碼裡的一個小坑,(也不算坑吧,就是一個用法問題)。在做resize和crop的時候一般直接用就可以,但是做colorjitter或者類似的對影象的資料進行處理,需要先呼叫mx.image.CastAug()就是資料型別轉換,不然會報錯。

import MyAudmentation

taug_list_train=[ 
                mx.image.ForceResizeAug(size=(shape_,shape_)), 
                mx.image.RandomCropAug((shape_,shape_)), 
                mx.image.HorizontalFlipAug(0.5), 
                mx.image.CastAug(),
                ##################!!!!!!!!caution
                mx.image.ColorJitterAug(0.0, 0.1, 0.1),
                mx.image.HueJitterAug(0.5), 
                mx.image.LightingAug(0.1, eigval, eigvec),
                #####呼叫旋轉增強旋轉30度,0.5的概率
                MyAugmentation.RandomRotateAug(30,0.5)
                ]
train_iter = mx.image.ImageIter(batch_size=batch_size,
                                    data_shape=shape,
                                    label_width=1,
                                    aug_list=aug_list_train,
                                    shuffle=True,
                                    path_root='',
                                    path_imglist='/you/path/train.lst'
                                    )



MyAugmentation.py
import cv2
import mxnet as mx
from mxnet.image import  Augmenter
import random
import numpy as np
#######################實現對應的影象處理過程供呼叫
def rotate(src, angle, center=None, scale=1.0):
    image = src.asnumpy()
    (h, w) = image.shape[:2]
    # set the center point as the rotate center by default
    if center is None:
        center = (w / 2, h / 2)
    # opencv to 
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    rotated = mx.nd.array(rotated,dtype=np.uint8)
    
    return rotated

def SaltAndPepper(src,percet):
    ###it is a very slow mothed, not recommended  to use it
    Salted=src
    image=int(percet*src.shape[0]*src.shape[1])
    for i in range(image):
        randX=random.randint(0,src.shape[0]-1)
        randY=random.randint(0,src.shape[1]-1)
        if random.randint(0,1)==0:
            Salted[randX,randY]=0.
        else:
            Salted[randX,randY]=255.
    return Salted



#######繼承Augmenter,並實現兩個方法即可

#####################
class RandomRotateAug(Augmenter):
    """Make randomrotate.
    Parameters
    ----------
    angel : float or int the max angel to rotate
    p : the possibility the img be rotated
    """
    def __init__(self, angel, possibility):
        super(RandomRotateAug, self).__init__(angel=angel)
        self.maxangel = angel
        self.p=possibility
    def __call__(self, src):
        """Augmenter body"""
        #return resize_short(src, self.size, self.interp)
        a = random.random()
        if a > self.p:
            return src
        else:
            angle=random.randint(-self.maxangel,self.maxangel)
            return rotate(src,angle)



class RandomNoiseAug(Augmenter):
    """Make randomrotate.
    Parameters
    ----------
    percet : how much should the img be noised
    p : the possibility the img be noised
    """
    def __init__(self, percet,possibility):
        super(RandomNoiseAug, self).__init__(percet=percet)
        self.percet = percet
        self.p=possibility
    def __call__(self, src):
        """Augmenter body"""
        #return resize_short(src, self.size, self.interp)
        a = random.random()
        if a > self.p:
            return src
        else:
            return SaltAndPepper(src,self.percet)

後續逐漸會有構造customed dataiterator,以及customed operator的介紹,如有錯誤請指正,並請諒解:)