1. 程式人生 > >tensorflow實現資料增強(隨機裁剪、翻轉、對比度設定、亮度設定)

tensorflow實現資料增強(隨機裁剪、翻轉、對比度設定、亮度設定)

資料增強(Data Augmentation):是指對圖片進行隨機的旋轉、翻轉、裁剪、隨機設定圖片的亮度和對比度以及對資料進行標準化(資料的均值為0,方差為1)。通過這些操作,我們可以獲得更多的圖片樣本,原來的一張圖片可以變為多張圖片,擴大了樣本容量,對於提高模型的準確率和提升模型的泛化能力非常有幫助,在進行資料增強的同時也會需要消耗大量的系統資源。

利用opencv來讀取圖片,然後利用tensorflow來對圖片進行增強處理,最後再通過matplotlib來顯示圖片,需要注意的是matplotlib顯示圖片的時候是使用RGB通道順序來顯示圖片,而opencv則是採用BGR的順序來處理圖片的,所以在對圖片進行imshow之前需要先進行通道轉換。

1、隨機裁剪

原始圖片的大小為300×300,將圖片隨機裁剪為280×280,通道大小不變。

import tensorflow as tf
import cv2
import matplotlib.pyplot as plt
#用來正常顯示中文
plt.rcParams["font.sans-serif"]=["SimHei"]

if __name__ == "__main__":
    img = cv2.imread("img/img.jpg")
    #將圖片進行隨機裁剪為280×280
    crop_img = tf.random_crop(img,[280,280,3])
    sess = tf.InteractiveSession()
    #顯示圖片
    # cv2.imwrite("img/crop.jpg",crop_img.eval())
    plt.figure(1)
    plt.subplot(121)
    #將圖片由BGR轉成RGB
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    plt.imshow(img)
    plt.title("原始圖片")
    plt.subplot(122)
    crop_img = cv2.cvtColor(crop_img.eval(),cv2.COLOR_BGR2RGB)
    plt.title("裁剪後的圖片")
    plt.imshow(crop_img)
    plt.show()
    sess.close()

2、隨機翻轉

對圖片的水平方向和垂直方向進行隨機翻轉。

    img = cv2.imread("img/img.jpg")
    #將圖片隨機進行水平翻轉
    h_flip_img = tf.image.random_flip_left_right(img)
    #將圖片隨機進行垂直翻轉
    v_flip_img = tf.image.random_flip_up_down(img)
    sess = tf.InteractiveSession()
    #通道轉換
    img = cv2.cvtColor(img,cv2.COLOR_BGR2RGB)
    h_flip_img = cv2.cvtColor(h_flip_img.eval(),cv2.COLOR_BGR2RGB)
    v_flip_img = cv2.cvtColor(v_flip_img.eval(),cv2.COLOR_BGR2RGB)
    #顯示圖片
    plt.figure(1)
    plt.subplot(131)
    plt.title("水平翻轉")
    plt.imshow(h_flip_img)
    plt.subplot(132)
    plt.title("垂直翻轉")
    plt.imshow(v_flip_img)
    plt.subplot(133)
    plt.title("原始圖片")
    plt.imshow(img)
    plt.show()

3、隨機亮度、對比度、色度、飽和度的設定

    #隨機設定圖片的亮度
    random_brightness = tf.image.random_brightness(img,max_delta=30)
    #隨機設定圖片的對比度
    random_contrast = tf.image.random_contrast(img,lower=0.2,upper=1.8)
    #隨機設定圖片的色度
    random_hue = tf.image.random_hue(img,max_delta=0.3)
    #隨機設定圖片的飽和度
    random_satu = tf.image.random_saturation(img,lower=0.2,upper=1.8)

4、圖片的標準化

標準化後圖片的均值為0,方差為1

    img = cv2.imread("img/img.jpg")
    #將圖片進行標準化
    std_img = tf.image.per_image_standardization(img)
    sess = tf.InteractiveSession()
    print(std_img.eval())