1. 程式人生 > >Data Augment ------TensorFlow 訓練圖片處理

Data Augment ------TensorFlow 訓練圖片處理

Deep Learning 是一種基於資料的分析模型或分析方法。資料量足夠大時,分析效果更佳,也更能分析出結果;如果收集的資料量過小,訓練出來的結果可能會出乎意料。我在此針對的是CNN模型。其實該模型並非一定需要大量資料才能得到結果,這還涉及到您訓練模型的方式(例如fine-tune)、採集的資料是否夠優質,以及資料之間的區分度等。方法1:如果要對某個大類中的資料再進行細分類,我們可以將圖片拆解成不同區域,分別由多個模型輸出,最後再進行綜合判定。這個方法是我閱讀RCNN、Faster-RCNN等論文後的心得,之後會在MIT的bird資料集上進行測試和驗證,具體內容後續再寫。方法2:遷移學習模型,主要是遷移在ImageNet資料集上訓練好的模型。該方法目前表現不錯,我也在Category101上測試過,基於ResNet50,大約15個epoch時表現就已經很好了。方法3:細粒度(fine-grained)分類模型訓練,其實對少量資料的訓練效果也非常不錯。我關注過近年來的論文,細粒度模型也能夠進行區分,但該方法目前還需要完善。方法4:使用Data Augment(資料擴充)方法,例如圖片翻轉等。我們都知道,圖片實際上是一個矩陣,所謂的翻轉等操作,實際上就是對矩陣中的資料進行變換。我在學習過程中需要擴充資料,因此寫了一個資料擴充的類:

def floatrange(start, stop, steps):
    """Return *steps* evenly spaced floats from *start* to *stop*, both inclusive.

    Args:
        start: First value of the sequence.
        stop: Last value of the sequence (reached when ``steps >= 2``).
        steps: Number of values to produce.

    Returns:
        A list of floats. ``steps <= 0`` gives ``[]`` and ``steps == 1`` gives
        ``[float(start)]``.
    """
    if steps <= 0:
        return []
    if steps == 1:
        # The spacing formula divides by steps - 1, which would be zero here.
        return [float(start)]
    return [start + float(i) * (stop - start) / (float(steps) - 1) for i in range(steps)]

class DataAugment:
    """Generate augmented JPEG copies of one source image (TensorFlow 1.x graph API).

    Each ``get_image_*`` method builds an augmentation op on the decoded source
    image, evaluates it in a session and writes the result to ``<prefix><n>.jpg``.
    Here ``n`` is a running counter (``self.count``) that is incremented once per
    written file.

    Args:
        path: Path of the source JPEG file.
        shape: Static shape passed to ``Tensor.set_shape`` on every variant, e.g.
            ``[64, 64, 3]``. This is only a shape hint; the image is not resized.
        count: Initial counter value; the first file written is ``count + 1``.
        num1..num8: Number of files produced by each augmentation method (see
            each method's docstring for which one it uses).
        out_dir: Prefix prepended to every output file name (include a trailing
            separator). When ``None`` (default), the module-level global ``dir``
            is used if one is defined; otherwise the current directory is used.
    """

    def __init__(self, path, shape, count=0, num1=150, num2=100, num3=100,
                 num4=150, num5=100, num6=50, num7=200, num8=150, out_dir=None):
        # Read inside ``with`` so the file handle is closed after reading.
        with tf.gfile.FastGFile(path, 'rb') as src:
            raw = src.read()
        self.image_data = tf.image.decode_jpeg(raw)
        self.shape = shape
        self.count = count
        self.num1 = num1
        self.num2 = num2
        self.num3 = num3
        self.num4 = num4
        self.num5 = num5
        self.num6 = num6
        self.num7 = num7
        self.num8 = num8
        self.out_dir = out_dir

    def _next_name(self):
        """Increment the counter and return the file name for the next image."""
        prefix = self.out_dir
        if prefix is None:
            # Backward compatibility: scripts set the output prefix through a
            # module-level global named ``dir``. Look it up explicitly. A bare
            # ``dir`` would fall through to the builtin function when the global
            # is unset, and ``dir + str(n)`` would then raise TypeError.
            prefix = globals().get('dir', '')
        self.count += 1
        return prefix + str(self.count) + '.jpg'

    def _save_variant(self, image):
        """Attach the static shape hint to *image* and save it as the next file."""
        name = self._next_name()
        image.set_shape(self.shape)
        self.Save_image(name, image)

    def Save_image(self, name, x):
        """Encode image tensor *x* as JPEG and write the bytes to file *name*."""
        encoded = tf.image.encode_jpeg(x)
        with tf.Session() as sess:
            # No-op unless *x* depends on tf.Variables; kept for callers that pass such tensors.
            sess.run(tf.global_variables_initializer())
            data = sess.run(encoded)
        # Evaluate before opening the output, so a failure does not leave an empty file.
        with tf.gfile.FastGFile(name, 'wb') as save:
            save.write(data)

    def get_image_central(self):
        """Save ``num1`` central crops, with the kept fraction swept from 0.4 to 1.0."""
        for fraction in floatrange(0.4, 1.0, self.num1):
            self._save_variant(tf.image.central_crop(self.image_data, fraction))

    def get_image_random_tranpose(self):
        """Save ``num2`` copies of the transposed image (height and width swapped).

        ``tf.image.transpose_image`` is deterministic, so all copies are identical.
        """
        for _ in range(self.num2):
            self._save_variant(tf.image.transpose_image(self.image_data))

    def get_image_up_down(self):
        """Save ``num3`` images, each one either randomly flipped vertically or left unchanged."""
        for _ in range(self.num3):
            self._save_variant(tf.image.random_flip_up_down(self.image_data))

    def get_image_left_right(self):
        """Save ``num4`` images, each one either randomly flipped horizontally or left unchanged."""
        for _ in range(self.num4):
            self._save_variant(tf.image.random_flip_left_right(self.image_data))

    def get_image_bright(self):
        """Save ``num5`` images with the brightness delta swept from -0.6 to 0.6."""
        for delta in floatrange(-0.6, 0.6, self.num5):
            self._save_variant(tf.image.adjust_brightness(self.image_data, delta))

    def get_image_contrast(self):
        """Save ``num6`` images with a random integer contrast factor in [-4, 3].

        NOTE(review): factors 0 and negative values are included
        (``np.random.randint`` excludes the upper bound). Confirm this is intended.
        """
        for _ in range(self.num6):
            factor = np.random.randint(-4, 4)
            self._save_variant(tf.image.adjust_contrast(self.image_data, factor))

    def get_image_hue(self):
        """Save ``num7`` images via ``random_hue``, with max_delta swept from 0.0 to 0.5."""
        for max_delta in floatrange(0.0, 0.5, self.num7):
            self._save_variant(tf.image.random_hue(self.image_data, max_delta))

    def get_image_sation(self):
        """Save ``num8`` images with a random integer saturation factor in [-6, 9]."""
        for _ in range(self.num8):
            factor = np.random.randint(-6, 10)
            self._save_variant(tf.image.adjust_saturation(self.image_data, factor))

    def get_count(self):
        """Return the current value of the file-name counter."""
        return self.count

    def run(self):
        """Run all eight augmentations."""
        self.get_image_random_tranpose()
        self.get_image_left_right()
        self.get_image_central()
        self.get_image_bright()
        self.get_image_hue()
        self.get_image_contrast()
        self.get_image_sation()
        self.get_image_up_down()

    def run_1(self):
        """Run a reduced set: transpose, left/right flip, hue and saturation."""
        self.get_image_random_tranpose()
        self.get_image_left_right()
        self.get_image_hue()
        self.get_image_sation()

    def show(self):
        """Display the decoded source image with matplotlib.

        ``plt.imread`` expects a file name or file object, not encoded bytes.
        The decoded pixel array is therefore evaluated and drawn with ``plt.imshow``.
        """
        with tf.Session() as sess:
            pixels = sess.run(self.image_data)
        if pixels.ndim == 3 and pixels.shape[-1] == 1:
            # imshow rejects (H, W, 1) arrays; drop the channel axis for grayscale JPEGs.
            pixels = pixels[..., 0]
        plt.imshow(pixels, cmap='gray' if pixels.ndim == 2 else None)
        plt.show()

def get_image_path(path):
    """Return ``path + '/' + name`` for each non-hidden ``.jpg`` file directly in *path*.

    Names that start with '.' are skipped. The result is sorted by file name, so
    the sequential output numbering done by callers is reproducible
    (``os.listdir`` order is arbitrary).

    Returns:
        The list of file paths, or ``[]`` when *path* is not a directory.
    """
    if not os.path.isdir(path):
        return []
    return [path + '/' + name
            for name in sorted(os.listdir(path))
            if name.endswith('.jpg') and not name.startswith('.')]
def get_dir(dir):
    """Return ``dir + '/' + name`` for every immediate subdirectory of *dir*.

    Returns an empty list when *dir* is not a directory. The order follows
    ``os.listdir``.
    """
    if not os.path.isdir(dir):
        return []
    candidates = (dir + '/' + entry for entry in os.listdir(dir))
    return [candidate for candidate in candidates if os.path.isdir(candidate)]
def make_dirs(path, count=250, name='file'):
    """Create the directories ``path + name + str(i)`` for ``i`` in ``range(count)``.

    *path* is used as a plain string prefix, so include a trailing separator if
    the directories should be created inside it. Directories that already exist
    are left alone, so the function can safely be re-run.

    Args:
        path: String prefix for every directory name.
        count: Number of directories to create (default 250).
        name: Name stem placed between *path* and the index (default 'file').
    """
    for i in range(count):
        # ``os._exists`` is a private os-module helper and does not inspect the
        # filesystem; ``exist_ok=True`` performs the intended "skip if present" check.
        os.makedirs(path + name + str(i), exist_ok=True)

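呼叫範例如下(以上程式使用 TensorFlow 1.x 的 API,例如 tf.gfile、tf.Session;執行前需先 import tensorflow as tf、import numpy as np、import matplotlib.pyplot as plt 以及 import os):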


if __name__ == '__main__':
    # DataAugment reads its output-file prefix from the module-level global
    # ``dir``. Both prompts are asked before any augmentation runs, so ``dir`` is
    # bound for the demo below. Otherwise the name resolves to the builtin dir()
    # and saving raises TypeError.
    path = '/Users/josen/Desktop/256_ObjectCategories/' + input('dir:')
    dir = input('file-path:')

    # Demo: augment a single image. run_1 calls four methods with 2 variants
    # each, which gives 8 files.
    demo = DataAugment('/Users/josen/Desktop/001_0003.jpg', [64, 64, 3], 0, 2, 2, 2, 2, 2, 2, 2, 2)
    demo.run_1()
    print('the final data is ', str(demo.get_count()))

    # Augment every .jpg in the chosen category. The counter carries over
    # between images, so file names do not collide within this pass.
    # NOTE(review): numbering restarts at 1 in the same ``dir``, so the demo's
    # files are overwritten. Confirm this is intended.
    cout = 0
    for image_path in get_image_path(path):
        augment = DataAugment(image_path, [64, 64, 3], cout, 2, 2, 2, 2, 2, 2, 2, 2)
        augment.run_1()
        cout = augment.get_count()
    print(cout)

    # Augment images from a second folder until at least 650 files are written.
    # NOTE(review): this pass also restarts at 1 in the same ``dir`` and
    # overwrites the previous pass. Confirm the intended output folder.
    root = '/Users/josen/Desktop/Object256/' + input('file:--')
    cout = 0
    for image_path in get_image_path(root):
        augment = DataAugment(image_path, [64, 64, 3], cout, 2, 2, 2, 2, 2, 2, 2, 2)
        augment.run_1()
        cout = augment.get_count()
        # The counter advances in steps of 8 per image here (648 -> 656), so an
        # equality test against 650 would never fire. Use >= instead.
        if cout >= 650:
            break
    print(cout)