使用model.fit_generator方法進行訓練(自己的訓練集-多分類)
阿新 • • 發佈:2018-12-09
我們在使用model.fit()進行訓練的時候, 在這之前你肯定會有訓練集的x_img_train,y_label_train兩個引數。
fit(x=None, y=None, batch_size=None, epochs=1, verbose=1, callbacks=None, validation_split=0.0, validation_data=None, shuffle=True, class_weight=None, sample_weight=None, initial_epoch=0, steps_per_epoch=None, validation_steps=None)
但是當我們使用model.fit_generator()的時候,它的方法是這樣的:
fit_generator(generator, steps_per_epoch=None, epochs=1, verbose=1, callbacks=None, validation_data=None, validation_steps=None, class_weight=None, max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
可以看到它要求傳入的引數是一個generator.官網說的很清楚,( 不清楚的可以看官網)這裡的generator是一個生成器,主要是訓練自己的資料,並且資料非常多的時候可以不用把資料全部載入進記憶體,而是用生成器自己一點點讀取。大大提高的執行效率。
下面是這個生成器的生成方法:
#這是訓練集的生成器 train_datagen = ImageDataGenerator( rescale=1. / 255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True) ## 訓練圖片生成器 train_generator = train_datagen.flow_from_directory( train_data_dir,#訓練樣本地址 target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical') #多分類 test_datagen = ImageDataGenerator(rescale=1. / 255) ##驗證集的生成器 validation_generator = test_datagen.flow_from_directory( validation_data_dir,#驗證樣本地址 target_size=(img_width, img_height), batch_size=batch_size, class_mode='categorical', shuffle=False) #多分類
好了,有了這個train_generator生成器我們就可以入入fit_generator(...)裡面進行訓練了。
對了,這裡說明下train_data_dir / validation_data_dir 是我本機的訓練集與驗證集的地址。
目錄結構形似:
'''
data/train/
1/
001.jpg
002.jpg
...
2/
001.jpg
002.jpg
...
data/validation/
1/
001.jpg
002.jpg
...
2/
001.jpg
002.jpg
...
'''