1. 程式人生 > >基本資料增強處理

基本資料增強處理

資料增強主要包含如下方式: 1.旋轉: 可通過在原圖上先放大影象,然後剪下影象得到。 2.平移:先放大影象,然後水平或垂直偏移位置剪下 3.縮放:縮放影象 4.水平翻轉:以過影象中心的豎直軸為對稱軸,將左、右兩邊畫素交換 填充模式:最近鄰方式 5.顏色色差(飽和度、亮度、對比度、 銳度等)
相關Python原始碼: """資料增強 1. 翻轉變換 flip 2. 隨機修剪 random crop 3. 色彩抖動 color jittering 4. 平移變換 shift 5. 尺度變換 scale 6. 對比度變換 contrast 7. 噪聲擾動 noise 8. 旋轉變換/反射變換 Rotation/reflection from PIL import Image, ImageEnhance, ImageOps, ImageFile import numpy as np import random import threading, os, time import logging logger = logging.getLogger(__name__) ImageFile.LOAD_TRUNCATED_IMAGES = True class DataAugmentation: """ 包含資料增強的八種方式 """ def __init__(self): pass @staticmethod def openImage(image): return Image.open(image, mode="r") @staticmethod def randomRotation(image, mode=Image.BICUBIC): """ 對影象進行隨機任意角度(0~360度)旋轉 :param mode 鄰近插值,雙線性插值,雙三次B樣條插值(default) :param image PIL的影象image :return: 旋轉轉之後的影象 """ random_angle = np.random.randint(1, 360) return image.rotate(random_angle, mode) @staticmethod def randomCrop(image): """ 對影象隨意剪下,考慮到影象大小範圍(68,68),使用一個一個大於(36*36)的視窗進行截圖 :param image: PIL的影象image :return: 剪下之後的影象 """ image_width = image.size[0] image_height = image.size[1] crop_win_size = np.random.randint(40, 68) random_region = ( (image_width - crop_win_size) >> 1, (image_height - crop_win_size) >> 1, (image_width + crop_win_size) >> 1, (image_height + crop_win_size) >> 1) return image.crop(random_region) @staticmethod def randomColor(image): """ 對影象進行顏色抖動 :param image: PIL的影象image :return: 有顏色色差的影象image """ random_factor = np.random.randint(0, 31) / 10. # 隨機因子 color_image = ImageEnhance.Color(image).enhance(random_factor) # 調整影象的飽和度 random_factor = np.random.randint(10, 21) / 10. # 隨機因子 brightness_image = ImageEnhance.Brightness(color_image).enhance(random_factor) # 調整影象的亮度 random_factor = np.random.randint(10, 21) / 10. # 隨機因1子 contrast_image = ImageEnhance.Contrast(brightness_image).enhance(random_factor) # 調整影象對比度 random_factor = np.random.randint(0, 31) / 10. # 隨機因子 return ImageEnhance.Sharpness(contrast_image).enhance(random_factor) # 調整影象銳度 @staticmethod def randomGaussian(image, mean=0.2, sigma=0.3): """ 對影象進行高斯噪聲處理 :param image: :return: """ def gaussianNoisy(im, mean=0.2, sigma=0.3): """ 對影象做高斯噪音處理 :param im: 單通道影象 :param mean: 偏移量 :param sigma: 標準差 :return: """ for _i in range(len(im)): im[_i] += random.gauss(mean, sigma) return im # 將影象轉化成陣列 img = np.asarray(image) img.flags.writeable = True # 將陣列改為讀寫模式 width, height = img.shape[:2] img_r = gaussianNoisy(img[:, :, 0].flatten(), mean, sigma) img_g = gaussianNoisy(img[:, :, 1].flatten(), mean, sigma) img_b = gaussianNoisy(img[:, :, 2].flatten(), mean, sigma) img[:, :, 0] = img_r.reshape([width, height]) img[:, :, 1] = img_g.reshape([width, height]) img[:, :, 2] = img_b.reshape([width, height]) return Image.fromarray(np.uint8(img)) @staticmethod def saveImage(image, path): image.save(path) def makeDir(path): try: if not os.path.exists(path): if not os.path.isfile(path): # os.mkdir(path) os.makedirs(path) return 0 else: return 1 except Exception, e: print str(e) return -2 def imageOps(func_name, image, des_path, file_name, times=5): funcMap = {"randomRotation": DataAugmentation.randomRotation, "randomCrop": DataAugmentation.randomCrop, "randomColor": DataAugmentation.randomColor, "randomGaussian": DataAugmentation.randomGaussian } if funcMap.get(func_name) is None: logger.error("%s is not exist", func_name) return -1 for _i in range(0, times, 1): new_image = funcMap[func_name](image) DataAugmentation.saveImage(new_image, os.path.join(des_path, func_name + str(_i) + file_name)) opsList = {"randomRotation", "randomCrop", "randomColor", "randomGaussian"} def threadOPS(path, new_path): """ 多執行緒處理事務 :param src_path: 資原始檔 :param des_path: 目的地檔案 :return: """ if os.path.isdir(path): img_names = os.listdir(path) else: img_names = [path] for img_name in img_names: print img_name tmp_img_name = os.path.join(path, img_name) if os.path.isdir(tmp_img_name): if makeDir(os.path.join(new_path, img_name)) != -1: threadOPS(tmp_img_name, os.path.join(new_path, img_name)) else: print 'create new dir failure' return -1 # os.removedirs(tmp_img_name) elif tmp_img_name.split('.')[1] != "DS_Store": # 讀取檔案並進行操作 image = DataAugmentation.openImage(tmp_img_name) threadImage = [0] * 5 _index = 0 for ops_name in opsList: threadImage[_index] = threading.Thread(target=imageOps, args=(ops_name, image, new_path, img_name,)) threadImage[_index].start() _index += 1 time.sleep(0.2) if __name__ == '__main__': threadOPS("/home/pic-image/train/12306train", "/home/pic-image/train/12306train3")