Implementing Your Own Image Dataset Augmentation in MXNet
阿新 • Published: 2019-01-03
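In deep learning for images, data augmentation is a common and very effective technique. It helps combat overfitting and improves model accuracy, and different problems sometimes call for specific transformations of the data. MXNet already ships with several common augmentations, such as random crop, mirroring (horizontal flip), and color jitter.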
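To make two of the built-in operations concrete, here is a minimal NumPy sketch of what a random crop and a random mirror do to an HWC image array. It only illustrates the idea and is not MXNet's implementation; the function names `random_crop` and `random_mirror` are hypothetical.

```python
import numpy as np

def random_crop(img, size, rng):
    """Cut a random (h, w) window out of an HWC image."""
    h, w = size
    top = rng.integers(0, img.shape[0] - h + 1)   # upper bound is exclusive
    left = rng.integers(0, img.shape[1] - w + 1)
    return img[top:top + h, left:left + w]

def random_mirror(img, p, rng):
    """Flip the image left-right with probability p."""
    if rng.random() < p:
        return img[:, ::-1]
    return img

rng = np.random.default_rng(0)
img = np.arange(4 * 6 * 3, dtype=np.uint8).reshape(4, 6, 3)
out = random_mirror(random_crop(img, (3, 3), rng), 0.5, rng)
print(out.shape)  # (3, 3, 3)
```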
However, when you need a specific, targeted augmentation, you have to write your own augmenter, and MXNet makes this fairly simple and flexible.
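The most convenient approach is to subclass Augmenter at the Python level. The process is straightforward: inherit from Augmenter and implement its methods (`__init__` to store the parameters and `__call__` to transform one image), as the code below shows.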
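The contract an augmenter has to satisfy is small: it is a callable that takes one image and returns one image, and the iterator runs the augmenters of its `aug_list` one after another, each receiving the previous one's output. The sketch below shows that pattern with plain NumPy so it runs without MXNet; `BaseAug` is a hypothetical stand-in for `mxnet.image.Augmenter`, and `RandomInvertAug` and `apply_augs` are hypothetical names used only for illustration.

```python
import random

import numpy as np

class BaseAug:
    """Hypothetical stand-in for mxnet.image.Augmenter: a callable img -> img."""
    def __call__(self, src):
        raise NotImplementedError

class RandomInvertAug(BaseAug):
    """Hypothetical example: invert uint8 pixel values with probability p."""
    def __init__(self, p):
        self.p = p

    def __call__(self, src):
        if random.random() >= self.p:
            return src          # leave the image unchanged
        return 255 - src        # invert every pixel

def apply_augs(img, aug_list):
    """Run the augmenters in list order, feeding each one the previous output."""
    for aug in aug_list:
        img = aug(img)
    return img

img = np.full((2, 2, 3), 10, dtype=np.uint8)
out = apply_augs(img, [RandomInvertAug(1.0)])
print(out[0, 0])  # [245 245 245]
```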
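The code includes a rotation augmentation (implemented with OpenCV) and a noise augmentation. Because most of the noise augmentation work is done in Python-level loops, it is rather slow and could be improved.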
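One possible way to remove the per-pixel Python loop is to pick all noisy positions at once with NumPy fancy indexing. This is a sketch under the assumption that the image is first converted to a NumPy array (for example with `src.asnumpy()`, as `rotate()` does) and converted back with `mx.nd.array(...)` afterwards; `salt_and_pepper_fast` is a hypothetical name, not part of MXNet.

```python
import numpy as np

def salt_and_pepper_fast(image, percent, rng=None):
    """Vectorized salt-and-pepper noise for an HW or HWC NumPy array.

    Like the loop version, it draws int(percent * H * W) random positions
    (duplicates allowed) and sets each one to 0 or 255 with equal probability.
    """
    rng = np.random.default_rng() if rng is None else rng
    noisy = image.copy()  # leave the caller's array untouched
    h, w = image.shape[:2]
    n = int(percent * h * w)
    rows = rng.integers(0, h, size=n)
    cols = rng.integers(0, w, size=n)
    values = rng.integers(0, 2, size=n) * 255  # 0 = pepper, 255 = salt
    if image.ndim == 3:
        values = values[:, None]  # same value for every channel of a pixel
    noisy[rows, cols] = values
    return noisy
```

All random positions and values are drawn in three vectorized calls, so the cost no longer grows with a Python loop iteration per noisy pixel.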
The custom augmenters are used exactly like MXNet's built-in ones. For example, with the code saved as MyAugmentation.py:
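CAUTION: A small pitfall with MXNet's image augmenters (not really a bug, more a usage detail). Resize and crop augmenters can be used directly, but before ColorJitterAug or any similar augmenter that operates on pixel values, you must first call mx.image.CastAug() to convert the data type (to float32 by default); otherwise you will get an error.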
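The underlying issue is that decoded images are uint8, which cannot hold the intermediate values that pixel arithmetic produces. MXNet reports an error in this situation; the short NumPy sketch below shows a related hazard that NumPy does not report at all, namely that in-place uint8 arithmetic silently wraps around. It only illustrates why casting to float first matters and does not reproduce MXNet's internals.

```python
import numpy as np

pixels = np.array([10, 128, 250], dtype=np.uint8)

# In-place integer arithmetic stays uint8 and silently wraps around modulo 256
wrapped = pixels.copy()
wrapped += 10
print(wrapped)   # [ 20 138   4]

# Casting to float32 first (what CastAug does for MXNet images) keeps the true values
as_float = pixels.astype(np.float32)
as_float += 10
print(as_float)  # [ 20. 138. 260.]
```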
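```python
import mxnet as mx
import numpy as np

import MyAugmentation  # the custom augmenter module listed below

# Example settings (assumed values, adjust to your data)
shape_ = 224
batch_size = 32
# RGB PCA eigenvalues/eigenvectors for LightingAug (the ImageNet values used by mx.image.CreateAugmenter)
eigval = np.array([55.46, 4.794, 1.148])
eigvec = np.array([[-0.5675, 0.7192, 0.4009],
                   [-0.5808, -0.0045, -0.8140],
                   [-0.5836, -0.6948, 0.4203]])

aug_list_train = [
    mx.image.ForceResizeAug(size=(shape_, shape_)),
    mx.image.RandomCropAug((shape_, shape_)),
    mx.image.HorizontalFlipAug(0.5),
    mx.image.CastAug(),  # !!! CAUTION: cast to float32 before pixel-value augmenters
    mx.image.ColorJitterAug(0.0, 0.1, 0.1),
    mx.image.HueJitterAug(0.5),
    mx.image.LightingAug(0.1, eigval, eigvec),
    # custom rotation: up to ±30 degrees, applied with probability 0.5
    MyAugmentation.RandomRotateAug(30, 0.5),
]

train_iter = mx.image.ImageIter(batch_size=batch_size,
                                data_shape=(3, shape_, shape_),
                                label_width=1,
                                aug_list=aug_list_train,
                                shuffle=True,
                                path_root='',
                                path_imglist='/you/path/train.lst')
```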
MyAugmentation.py
```python
import random

import cv2
import mxnet as mx
import numpy as np
from mxnet.image import Augmenter


# ------------- image-processing helpers called by the augmenters -------------
def rotate(src, angle, center=None, scale=1.0):
    """Rotate an HWC NDArray by `angle` degrees (counter-clockwise) with OpenCV."""
    image = src.asnumpy()
    (h, w) = image.shape[:2]
    # use the image center as the rotation center by default
    if center is None:
        center = (w / 2, h / 2)
    # build the 2x3 affine rotation matrix and apply it; the output keeps the
    # original (w, h), so corners are cropped and uncovered areas are filled with 0
    M = cv2.getRotationMatrix2D(center, angle, scale)
    rotated = cv2.warpAffine(image, M, (w, h))
    # warpAffine drops a singleton channel axis, so restore the input shape
    rotated = rotated.reshape(image.shape)
    # keep the input dtype (float32 after CastAug) instead of forcing uint8
    return mx.nd.array(rotated, dtype=image.dtype)


def SaltAndPepper(src, percent):
    """Set about `percent` of the pixels to 0 or 255 (salt-and-pepper noise).

    This loops over pixels in Python and is very slow; not recommended as-is.
    """
    salted = src.copy()  # do not modify the caller's array in place
    num_pixels = int(percent * src.shape[0] * src.shape[1])
    for _ in range(num_pixels):
        rand_x = random.randint(0, src.shape[0] - 1)
        rand_y = random.randint(0, src.shape[1] - 1)
        if random.randint(0, 1) == 0:
            salted[rand_x, rand_y] = 0.
        else:
            salted[rand_x, rand_y] = 255.
    return salted


# ------------- subclass Augmenter and implement two methods -------------
class RandomRotateAug(Augmenter):
    """Randomly rotate an image.

    Parameters
    ----------
    angle : float or int
        the maximum rotation angle in degrees; the actual angle is drawn from [-angle, angle]
    possibility : float
        the probability that the image is rotated
    """
    def __init__(self, angle, possibility):
        super(RandomRotateAug, self).__init__(angle=angle, possibility=possibility)
        self.maxangle = angle
        self.p = possibility

    def __call__(self, src):
        """Augmenter body"""
        if random.random() > self.p:
            return src
        # uniform() also works for float limits, unlike randint()
        angle = random.uniform(-self.maxangle, self.maxangle)
        return rotate(src, angle)


class RandomNoiseAug(Augmenter):
    """Randomly add salt-and-pepper noise to an image.

    Parameters
    ----------
    percent : float
        the fraction of pixels to replace with noise
    possibility : float
        the probability that the image is noised
    """
    def __init__(self, percent, possibility):
        super(RandomNoiseAug, self).__init__(percent=percent, possibility=possibility)
        self.percent = percent
        self.p = possibility

    def __call__(self, src):
        """Augmenter body"""
        if random.random() > self.p:
            return src
        return SaltAndPepper(src, self.percent)
```
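`rotate()` relies on `cv2.getRotationMatrix2D` to build the 2×3 affine matrix that `cv2.warpAffine` applies. As a sketch, the NumPy function below rebuilds that matrix from the formula given in the OpenCV documentation (α = scale·cos θ, β = scale·sin θ, positive angle = counter-clockwise with the origin at the top-left), so you can check where a point ends up. `rotation_matrix_2d` and `apply_affine` are hypothetical helper names, not OpenCV functions.

```python
import numpy as np

def rotation_matrix_2d(center, angle_deg, scale=1.0):
    """2x3 affine matrix following the formula documented for cv2.getRotationMatrix2D."""
    cx, cy = center
    theta = np.deg2rad(angle_deg)
    alpha = scale * np.cos(theta)
    beta = scale * np.sin(theta)
    return np.array([
        [alpha, beta, (1 - alpha) * cx - beta * cy],
        [-beta, alpha, beta * cx + (1 - alpha) * cy],
    ])

def apply_affine(M, x, y):
    """Map a source point (x, y) to its destination under the 2x3 matrix M."""
    return M @ np.array([x, y, 1.0])

M = rotation_matrix_2d((50.0, 40.0), 90)
print(apply_affine(M, 50.0, 40.0))  # the center stays fixed (approximately [50, 40])
print(apply_affine(M, 60.0, 40.0))  # a point right of center moves above it (approximately [50, 30])
```

Because `rotate()` keeps the output size at (w, h), points near the corners are mapped outside the frame, and the uncovered area is filled with zeros by warpAffine's default constant border.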
Upcoming posts will cover building a custom data iterator and custom operators. If you spot any mistakes, please point them out, and thanks for your understanding :)