1. 程式人生 > 其它 >七、VGG實現鳥類資料庫分類

七、VGG實現鳥類資料庫分類

目錄

前文

加利福尼亞理工學院鳥類資料庫分類

資料生成器

from keras.preprocessing.image import ImageDataGenerator

IMSIZE = 224
train_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/data_vgg/train',
                                                                           target_size=(IMSIZE, IMSIZE),
                                                                           batch_size=20,
                                                                           class_mode='categorical'
                                                                           )

validation_generator = ImageDataGenerator(rescale=1. / 255).flow_from_directory('../../data/data_vgg/test',
                                                                                target_size=(IMSIZE, IMSIZE),
                                                                                batch_size=20,
                                                                                class_mode='categorical'
                                                                                )
                                                                           )

影象顯示

from matplotlib import pyplot as plt

plt.figure()
fig, ax = plt.subplots(2, 5)
fig.set_figheight(6)
fig.set_figwidth(15)
ax = ax.flatten()
X, Y = next(validation_generator)
for i in range(15): ax[i].imshow(X[i, :, :, ])

VGG模型構建

from keras.layers import Conv2D, MaxPooling2D
from keras.layers import Flatten, Dense, Input, Activation
from keras import Model
from keras.layers import GlobalAveragePooling2D

input_shape = (IMSIZE, IMSIZE, 3)
input_layer = Input(input_shape)
x = input_layer

x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(64, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(128, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(256, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = Conv2D(512, [3, 3], padding="same", activation='relu')(x)
x = MaxPooling2D((2, 2))(x)

x = GlobalAveragePooling2D()(x)

x = Dense(315)(x)
x = Activation('softmax')(x)
output_layer = x
model_vgg16 = Model(input_layer, output_layer)
model_vgg16.summary()

VGG模型編譯與擬合

from keras.optimizers import Adam

model_vgg16.compile(loss='categorical_crossentropy',
                    optimizer=Adam(lr=0.001),
                    metrics=['accuracy'])
model_vgg16.fit_generator(train_generator,
                          epochs=20,
                          validation_data=validation_generator)

注意:

因為自己是使用tensorflow-GPU版本,自己電腦是1050Ti,4G視訊記憶體。實際執行時候batch_size設定不到15大小,太大了就視訊記憶體資源不足。

但是batch_size太小,總的資料集較大較多,所以最後消耗時間就較長。

所以為了效率和燒顯示卡,請酌情考慮

資料集來源:kaggle平臺315種鳥類:315 Bird Species - Classification | Kaggle

GitHub下載地址:

Tensorflow1.15深度學習

本文來自部落格園,作者:李好秀,轉載請註明原文連結:https://www.cnblogs.com/lehoso/p/15614793.html