Keras 在fit_generator訓練方式中加入影象random_crop操作
使用Keras作前端寫網路時,由於訓練影象尺寸較大,需要做類似 tf.random_crop 影象裁剪操作。
為此研究了一番Keras下已封裝的API。
Data Augmentation(資料擴充)
Data Aumentation 指使用下面或其他方法增加輸入資料量。我們預設影象資料。
旋轉&反射變換(Rotation/reflection): 隨機旋轉影象一定角度; 改變影象內容的朝向;
翻轉變換(flip): 沿著水平或者垂直方向翻轉影象;
縮放變換(zoom): 按照一定的比例放大或者縮小影象;
平移變換(shift): 在影象平面上對影象以一定方式進行平移;
可以採用隨機或人為定義的方式指定平移範圍和平移步長,沿水平或豎直方向進行平移. 改變影象內容的位置;
尺度變換(scale): 對影象按照指定的尺度因子,進行放大或縮小; 或者參照SIFT特徵提取思想,利用指定的尺度因子對影象濾波構造尺度空間. 改變影象內容的大小或模糊程度;
對比度變換(contrast): 在影象的HSV顏色空間,改變飽和度S和V亮度分量,保持色調H不變. 對每個畫素的S和V分量進行指數運算(指數因子在0.25到4之間),增加光照變化;
噪聲擾動(noise): 對影象的每個畫素RGB進行隨機擾動,常用的噪聲模式是椒鹽噪聲和高斯噪聲;
Data Aumentation 有很多好處,比如資料量較少時,用資料擴充來增加訓練資料,防止過擬合。
ImageDataGenerator
在Keras中,ImageDataGenerator就是專門做資料擴充的。
from keras.preprocessing.image import ImageDataGenerator
注:Using TensorFlow backend.
官方寫法如下:
(x_train,y_train),(x_test,y_test) = cifar10.load_data() datagen = ImageDataGenerator( featurewise_center=True,... horizontal_flip=True) # compute quantities required for featurewise normalization datagen.fit(x_train) # 使用fit_generator的【自動】訓練方法: fits the model on batches with real-time data augmentation model.fit_generator(datagen.flow(x_train,y_train,batch_size=32),steps_per_epoch=len(x_train),epochs=epochs) # 自己寫range迴圈的【手動】訓練方法 for e in range(epochs): print 'Epoch',e batches = 0 for x_batch,y_batch in datagen.flow(x_train,batch_size=32): loss = model.train(x_batch,y_batch) batches += 1 if batches >= len(x_train) / 32: # we need to break the loop by hand because # the generator loops indefinitely break
ImageDataGenerator的引數說明見官網文件。
上面兩種訓練方法的差異不討論,我們要關注的是:官方封裝的訓練集batch生成器是ImageDataGenerator物件的flow方法(或flow_from_directory),該函式返回一個和python定義相似的generator。在它前一步,資料變換是ImageDataGenerator物件的fit方法。
random_crop並未在ImageDataGenerator中內建,但引數中給了一個preprocessing_function,我們可以利用它自定義my_random_crop函式,像下面這樣寫:
def my_random_crop(image): random_arr = numpy.random.randint(img_sz-crop_sz+1,size=2) y = int(random_arr[0]) x = int(random_arr[1]) h = img_crop w = img_crop image_crop = image[y:y+h,x:x+w,:] return image_crop datagen = ImageDataGenerator( featurewise_center=False,··· preprocessing_function=my_random_crop) datagen.fit(x_train)
fit方法呼叫時將預設的變換應用到x_train的每張圖上,包括影象crop,因為是單張依次處理,每張圖的crop位置隨機。
在訓練資料(x=image,y=class_label)時這樣寫已滿足要求;
但在(x=image,y=image_mask)時該方法就不成立了。影象單張處理的緣故,一對(image,image_mask)分別crop的位置無法保持一致。
雖然官網也給出了同時變換image和mask的寫法,但它提出的方案能保證二者內建函式的變換一致,自定義函式的random變數仍是隨機的。
fit_generator
既然ImageDataGenerator和flow方法不能滿足我們的random_crop預處理要求,就在fit_generator函式處想方法修改。
先看它的定義:
def fit_generator(self,generator,samples_per_epoch,nb_epoch,verbose=1,callbacks=[],validation_data=None,nb_val_samples=None,class_weight=None,max_q_size=10,**kwargs):
第一個引數generator,可以傳入一個方法,也可以直接傳入資料集。前面的 datagen.flow() 即是Keras封裝的批量資料傳入方法。
顯然,我們可以自定義。
def generate_batch_data_random(x,y,batch_size): """分批取batch資料載入到視訊記憶體""" total_num = len(x) batches = total_num // batch_size while (True): i = randint(0,batches) x_batch = x[i*batch_size:(i+1)*batch_size] y_batch = y[i*batch_size:(i+1)*batch_size] random_arr = numpy.random.randint(img_sz-crop_sz+1,size=2) y_pos = int(random_arr[0]) x_pos = int(random_arr[1]) x_crop = x_batch[:,y_pos:y_pos+crop_sz,x_pos:x_pos+crop_sz,:] y_crop = y_batch[:,:] yield (x_crop,y_crop)
這樣寫就符合我們同組image和mask位置一致的random_crop要求。
注意:
由於沒有使用ImageDataGenerator內建的資料變換方法,資料擴充則也需要自定義;由於沒有使用flow(…,shuffle=True,)方法,每個epoch的資料打亂需要自定義。
generator自定義時要寫成死迴圈,因為在每個epoch內,generate_batch_data_random是不會重複呼叫的。
補充知識:tensorflow中的隨機裁剪函式random_crop
tf.random_crop是tensorflow中的隨機裁剪函式,可以用來裁剪圖片。我採用如下圖片進行隨機裁剪,裁剪大小為原圖的一半。
如下是實驗程式碼
import tensorflow as tf import matplotlib.image as img import matplotlib.pyplot as plt sess = tf.InteractiveSession() image = img.imread('D:/Documents/Pictures/logo3.jpg') reshaped_image = tf.cast(image,tf.float32) size = tf.cast(tf.shape(reshaped_image).eval(),tf.int32) height = sess.run(size[0]//2) width = sess.run(size[1]//2) distorted_image = tf.random_crop(reshaped_image,[height,width,3]) print(tf.shape(reshaped_image).eval()) print(tf.shape(distorted_image).eval()) fig = plt.figure() fig1 = plt.figure() ax = fig.add_subplot(111) ax1 = fig1.add_subplot(111) ax.imshow(sess.run(tf.cast(reshaped_image,tf.uint8))) ax1.imshow(sess.run(tf.cast(distorted_image,tf.uint8))) plt.show()
如下是隨機實驗兩次的結果
以上這篇Keras 在fit_generator訓練方式中加入影象random_crop操作就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。