Transfer Learning in Colab, a Complete Example: the fruits-360 Dataset
阿新 • Published: 2020-11-21
This example assumes the fruits-360 dataset is already in the current directory. For how to load the dataset, see my other article.
Setup
import tensorflow as tf
import tensorflow.keras as keras
from tensorflow.keras.preprocessing.image import load_img, img_to_array, array_to_img, ImageDataGenerator
Creating the Generators
Create the ImageDataGenerator instances. Because this dataset is large enough, image augmentation is not needed.
train_datagen = ImageDataGenerator(rescale=1./255)
test_datagen = ImageDataGenerator(rescale=1./255)

train_generator = train_datagen.flow_from_directory(
    "fruits-360/Training",
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')

validation_generator = test_datagen.flow_from_directory(
    "fruits-360/Test",
    target_size=(100, 100),
    batch_size=32,
    class_mode='categorical')
After running this, output like the following indicates that the generators were created successfully.
Found 67692 images belonging to 131 classes.
Found 22688 images belonging to 131 classes.
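flow_from_directory treats each subdirectory of the given path as one class, which is where the "131 classes" figure comes from. As a rough sanity check of the directory layout, the counts can be approximated with the standard library alone. This is only a sketch: count_images_per_class is a hypothetical helper name, and it counts every file, whereas Keras only counts files with recognized image extensions.

```python
import os


def count_images_per_class(root):
    """Return {class_name: file_count} for each subdirectory of root."""
    counts = {}
    for name in sorted(os.listdir(root)):
        class_dir = os.path.join(root, name)
        if os.path.isdir(class_dir):
            # Count regular files only; nested folders are ignored.
            counts[name] = sum(
                1 for f in os.listdir(class_dir)
                if os.path.isfile(os.path.join(class_dir, f))
            )
    return counts


# Only runs when the dataset is actually present in the current directory.
if os.path.isdir("fruits-360/Training"):
    counts = count_images_per_class("fruits-360/Training")
    print(len(counts), "classes,", sum(counts.values()), "files")
```

If the printed numbers differ noticeably from the "Found ... images belonging to ... classes" lines, the directory structure is probably not what flow_from_directory expects.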
Model
The Xception model is used here.
from tensorflow.keras.applications.xception import preprocess_input
from tensorflow.keras.applications.xception import decode_predictions
from tensorflow.keras.applications.xception import Xception
tf.keras.backend.clear_session()

base_model = tf.keras.applications.Xception(
    weights='imagenet',  # Load weights pre-trained on ImageNet.
    input_shape=(100, 100, 3),
    include_top=False)  # Do not include the ImageNet classifier at the top.

input_layer = tf.keras.Input(shape=(100, 100, 3))
base_model.trainable = False

# x = data_augmentation(input_layer)
x = base_model(input_layer, training = False)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(64, activation = 'relu')(x)
x = tf.keras.layers.Dropout(0.2)(x)  # Regularize with dropout
output_layer = tf.keras.layers.Dense(131, activation = 'softmax')(x)

model = tf.keras.Model(input_layer, output_layer)
model.summary()
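A side note on input scaling: the generators above feed the model pixels scaled to [0, 1] via rescale=1./255, while the preprocess_input function imported for Xception (and not otherwise used in this article) maps pixels to [-1, 1]. The sketch below only compares the two mappings on single pixel values; the function names are made up for illustration, and whether switching mappings changes the accuracy on fruits-360 is not measured here.

```python
def rescale_01(pixel):
    """Same mapping as ImageDataGenerator(rescale=1./255)."""
    return pixel * (1.0 / 255)


def scale_minus1_to_1(pixel):
    """Mapping applied by Xception's preprocess_input: x / 127.5 - 1."""
    return pixel / 127.5 - 1.0


# Compare the two mappings at the ends and near the middle of the 8-bit range.
for p in (0, 128, 255):
    print(p, round(rescale_01(p), 3), round(scale_minus1_to_1(p), 3))
```

One way to try the second mapping, if desired, would be to pass preprocessing_function=preprocess_input to ImageDataGenerator in place of rescale.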
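base_model.trainable = False freezes the weights of the Xception model so that they are not updated during training; in other words, the pre-trained weights are used as they are.
In x = base_model(input_layer, training = False), training=False ensures that the base model runs in inference mode rather than training mode. This matters mainly for the BatchNormalization layers in Xception, which then keep using their stored moving statistics instead of batch statistics.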
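The effect of freezing can be checked by counting trainable and non-trainable weights. The sketch below uses a tiny stand-in base model (a Conv2D layer plus BatchNormalization, not Xception itself) so that it builds quickly and downloads nothing; the layer sizes and the 5-class head are arbitrary assumptions for illustration.

```python
import tensorflow as tf

# Tiny stand-in for the pre-trained base (hypothetical, not Xception).
base_in = tf.keras.Input(shape=(8, 8, 3))
h = tf.keras.layers.Conv2D(4, 3)(base_in)
h = tf.keras.layers.BatchNormalization()(h)
base = tf.keras.Model(base_in, h)
base.trainable = False  # freeze every weight in the base

# New classification head on top of the frozen base.
inputs = tf.keras.Input(shape=(8, 8, 3))
x = base(inputs, training=False)  # keep BatchNormalization in inference mode
x = tf.keras.layers.GlobalAveragePooling2D()(x)
outputs = tf.keras.layers.Dense(5, activation="softmax")(x)
model = tf.keras.Model(inputs, outputs)

n_trainable = len(model.trainable_weights)
n_frozen = len(model.non_trainable_weights)
print("trainable:", n_trainable, "frozen:", n_frozen)
```

With this stand-in, only the kernel and bias of the Dense head remain trainable; everything inside the base is reported as non-trainable. The same check can be run on the Xception-based model from this article.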
opt = tf.keras.optimizers.Adam(learning_rate=0.01)
model.compile(loss='categorical_crossentropy', optimizer=opt, metrics = ['accuracy'])
model.fit(train_generator, epochs=5, steps_per_epoch = 67692//32, validation_data=validation_generator)
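The steps_per_epoch value 67692//32 uses floor division, so it counts only full batches. The arithmetic below shows how that compares with rounding up; the variable names are chosen for illustration.

```python
import math

num_train_images = 67692  # from the "Found 67692 images" output
batch_size = 32

floor_steps = num_train_images // batch_size            # full batches only
ceil_steps = math.ceil(num_train_images / batch_size)   # includes the partial batch
leftover = num_train_images - floor_steps * batch_size  # images in the partial batch

print(floor_steps, ceil_steps, leftover)
```

The difference is a single partial batch per epoch, which is negligible for a dataset of this size. As far as I know, len(train_generator) reports the rounded-up batch count, so it can be used in place of the hard-coded number.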