7. Classifying a Bird Species Dataset with VGG
阿新 | Published: 2021-11-28
Previous posts in this series
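- 1. Installing TensorFlow 2.x (2.6) on Windows
- 2. Deep Learning: Reading Data
- 3. Image Processing Operations in TensorFlow
- 4. Implementing a Linear Regression Model in TensorFlow
- 5. Deep Learning: The Logistic Regression Model
- 6. Chinese Font Recognition with AlexNet: Clerical Script (Lishu) vs. Running Script (Xingkai)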
Classifying the California Institute of Technology (Caltech) Bird Dataset
Data Generators
from keras.preprocessing.image import ImageDataGenerator

IMSIZE = 224

# Training images: one subfolder per species under ../../data/data_vgg/train
train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    '../../data/data_vgg/train',
    target_size=(IMSIZE, IMSIZE),
    batch_size=20,
    class_mode='categorical')

# Validation images: same folder layout under ../../data/data_vgg/test
validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory(
    '../../data/data_vgg/test',
    target_size=(IMSIZE, IMSIZE),
    batch_size=20,
    class_mode='categorical')
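`flow_from_directory` treats every subdirectory of `train` and `test` as one class and, when no `classes` list is passed, assigns label indices in alphanumeric order of the folder names; `class_mode='categorical'` then turns each index into a one-hot vector, and `rescale=1. / 255` multiplies 8-bit pixel values into the range [0, 1]. The plain-Python sketch below imitates those three steps without Keras so the mapping is easy to inspect. The helper names and the two toy class folders are hypothetical, and this is an illustration of the behavior, not the Keras implementation.

```python
import os
import tempfile


def infer_class_indices(root):
    """Map each subdirectory of root to an integer label in sorted (alphanumeric) order.

    Hypothetical helper that imitates how flow_from_directory orders classes
    when no explicit `classes` list is passed.
    """
    names = sorted(
        entry for entry in os.listdir(root)
        if os.path.isdir(os.path.join(root, entry))
    )
    return {name: index for index, name in enumerate(names)}


def one_hot(index, num_classes):
    """One-hot label vector, the per-image target produced by class_mode='categorical'."""
    vector = [0.0] * num_classes
    vector[index] = 1.0
    return vector


def rescale(pixels, factor=1.0 / 255):
    """Multiply 0-255 pixel values by 1/255, as ImageDataGenerator(rescale=1. / 255) does."""
    return [p * factor for p in pixels]


# Throwaway directory with two hypothetical class folders
with tempfile.TemporaryDirectory() as root:
    for name in ["BLUE JAY", "ALBATROSS"]:
        os.makedirs(os.path.join(root, name))
    class_indices = infer_class_indices(root)

print(class_indices)  # {'ALBATROSS': 0, 'BLUE JAY': 1}
print(one_hot(class_indices["BLUE JAY"], len(class_indices)))  # [0.0, 1.0]
print(rescale([0, 128, 255]))
```

With the real data, `train_generator.class_indices` holds the same kind of name-to-index dictionary for all species folders.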
Displaying Images
from matplotlib import pyplot as plt

# Take one batch (20 images) from the validation generator
X, Y = next(validation_generator)

fig, ax = plt.subplots(2, 5)
fig.set_figheight(6)
fig.set_figwidth(15)
ax = ax.flatten()
for i in range(10):  # the 2 x 5 grid has 10 axes
    ax[i].imshow(X[i, :, :, :])
    ax[i].axis('off')
plt.show()
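Each call to `next(validation_generator)` returns one batch: with the settings above, and assuming the test folder has one subfolder for each of the 315 species, `X` should have shape `(20, 224, 224, 3)` and `Y` shape `(20, 315)`, one one-hot row per image. To caption the plots with species names, a row of `Y` can be mapped back to its folder name by taking the position of its largest value and inverting the generator's `class_indices` dictionary. The sketch below does that in plain Python; `label_name` is a hypothetical helper, and the toy dictionary stands in for `validation_generator.class_indices`.

```python
def label_name(one_hot_row, class_indices):
    """Return the class (folder) name whose index holds the largest value in the row."""
    # Invert {name: index} into {index: name}
    index_to_name = {index: name for name, index in class_indices.items()}
    best_index = max(range(len(one_hot_row)), key=lambda i: one_hot_row[i])
    return index_to_name[best_index]


# Toy stand-in for validation_generator.class_indices (hypothetical class names)
class_indices = {"ALBATROSS": 0, "BLUE JAY": 1, "CARDINAL": 2}
print(label_name([0.0, 0.0, 1.0], class_indices))  # CARDINAL
```

Inside the plotting loop this could caption each image with `ax[i].set_title(label_name(Y[i], validation_generator.class_indices))`; the same helper also works on a row of softmax predictions.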
Building the VGG Model
from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Dense, Input, Activation
from keras.layers import GlobalAveragePooling2D
from keras import Model

input_shape = (IMSIZE, IMSIZE, 3)
input_layer = Input(input_shape)
x = input_layer

# Block 1: 2 x Conv(64), 224 -> 112
x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

# Block 2: 3 x Conv(128), 112 -> 56 (the original VGG16 uses two convolutions here)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

# Block 3: 3 x Conv(256), 56 -> 28
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

# Block 4: 3 x Conv(512), 28 -> 14
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

# Block 5: 3 x Conv(512), 14 -> 7
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

# Classifier head: global average pooling instead of VGG16's three fully connected layers
x = GlobalAveragePooling2D()(x)
x = Dense(315)(x)  # one output per bird species
x = Activation('softmax')(x)
output_layer = x

model_vgg16 = Model(input_layer, output_layer)
model_vgg16.summary()
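As a sanity check on `model_vgg16.summary()`, the shapes and parameter counts can be worked out by hand: a `Conv2D` with a 3x3 kernel has `3*3*in_channels*out_channels + out_channels` weights, `padding="same"` keeps the spatial size, and each of the five `MaxPooling2D((2, 2))` layers halves it, so 224 shrinks to 7 before `GlobalAveragePooling2D` reduces the 7x7x512 map to a 512-value vector feeding `Dense(315)`. The sketch below tallies this for the block layout used above; the `BLOCKS` list is just a transcription of that code, not a Keras API.

```python
# (filters, number of 3x3 conv layers) per block, transcribed from the model above
BLOCKS = [(64, 2), (128, 3), (256, 3), (512, 3), (512, 3)]
NUM_CLASSES = 315
IMSIZE = 224


def conv_params(in_channels, out_channels, kernel=3):
    """Weights plus biases of one Conv2D layer."""
    return kernel * kernel * in_channels * out_channels + out_channels


def summarize(blocks, in_channels=3, size=IMSIZE, num_classes=NUM_CLASSES):
    """Return (final spatial size, final channels, total trainable parameters)."""
    total = 0
    channels = in_channels
    for filters, repeats in blocks:
        for _ in range(repeats):
            total += conv_params(channels, filters)
            channels = filters
        size //= 2  # MaxPooling2D((2, 2)) halves height and width
    # GlobalAveragePooling2D has no weights; Dense adds channels * classes + biases
    total += channels * num_classes + num_classes
    return size, channels, total


final_size, final_channels, total_params = summarize(BLOCKS)
print(final_size, final_channels, total_params)  # 7 512 15023867
```

Since the pooling and activation layers carry no weights, this total should match the `Total params` line printed by `model_vgg16.summary()`.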
Compiling and Fitting the VGG Model
from keras.optimizers import Adam

model_vgg16.compile(loss='categorical_crossentropy',
                    optimizer=Adam(learning_rate=0.001),
                    metrics=['accuracy'])

# In TensorFlow 2.x, fit() accepts generators directly; fit_generator() is deprecated
model_vgg16.fit(train_generator,
                epochs=20,
                validation_data=validation_generator)
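`loss='categorical_crossentropy'` compares the softmax output with the one-hot label from the generator: for a single image it is -sum_k y_k * log(p_k), which reduces to minus the log of the probability assigned to the true class. A pure-Python sketch of softmax followed by that loss is below; it is a simplified illustration, not Keras's implementation (which also clips probabilities and averages over the batch). One consequence worth knowing: an untrained model that spreads probability roughly evenly over the 315 classes should start with a loss near ln(315), about 5.75.

```python
import math


def softmax(logits):
    """Turn raw scores into probabilities (shifted by the max for numerical stability)."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def categorical_crossentropy(one_hot, probs):
    """-sum_k y_k * log(p_k) for a single example (terms with y_k = 0 contribute nothing)."""
    return -sum(y * math.log(p) for y, p in zip(one_hot, probs) if y > 0)


# Equal logits over 315 classes give a uniform softmax, so the loss equals ln(315)
num_classes = 315
probs = softmax([0.0] * num_classes)
label = [0.0] * num_classes
label[42] = 1.0  # arbitrary true class
loss = categorical_crossentropy(label, probs)
print(round(loss, 4))  # 5.7526
```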
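Note:
I run the GPU build of TensorFlow on a GTX 1050 Ti with 4 GB of video memory. In practice I could not set batch_size as high as 15; anything larger ran out of GPU memory.
With such a small batch_size and a fairly large dataset, however, training takes a long time.
So weigh training speed against the strain on your graphics card and choose batch_size accordingly.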
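Part of the time cost described in the note comes from the number of steps per epoch: when `fit` receives a Keras directory iterator and no `steps_per_epoch`, it runs one pass of `ceil(num_images / batch_size)` batches per epoch, so halving the batch size roughly doubles the number of (smaller) steps. The sketch below only does that arithmetic; `NUM_IMAGES` is a made-up round number for illustration, not the size of the Kaggle dataset.

```python
def steps_per_epoch(num_images, batch_size):
    """Batches needed to see every image once; the last batch may be partial."""
    return (num_images + batch_size - 1) // batch_size  # integer ceiling division


# Hypothetical dataset size, for illustration only
NUM_IMAGES = 40000
for batch_size in (8, 15, 20, 32):
    print(batch_size, steps_per_epoch(NUM_IMAGES, batch_size))
```

How much wall-clock time each extra step costs depends on the GPU, but small batches generally keep it less busy per step.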
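Dataset source: the 315 bird species dataset on Kaggle (315 Bird Species - Classification | Kaggle)
GitHub download link:
This article is from cnblogs. Author: 李好秀. When reposting, please credit the original link: https://www.cnblogs.com/lehoso/p/15614793.html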