
Keras: Fine-tuning a Pretrained Network

Preparation

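The three steps of fine-tuning:

  • Build VGG-16 and load its weights;
  • Place the previously defined fully connected network on top of the model and load its weights;
  • Freeze part of the VGG16 network's parameters.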
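For step 3, the point where freezing stops can be derived from VGG16's layer names instead of being hard-coded as an index. A minimal pure-Python sketch, assuming a VGG16 convolutional base built only from Conv2D and MaxPooling2D layers (no padding layers) in the order of keras/applications/vgg16.py; the names `VGG16_CONV_LAYERS` and `freeze_boundary` are hypothetical:

```python
# Layer names of the VGG16 convolutional base, in the order they are added.
VGG16_CONV_LAYERS = [
    "block1_conv1", "block1_conv2", "block1_pool",
    "block2_conv1", "block2_conv2", "block2_pool",
    "block3_conv1", "block3_conv2", "block3_conv3", "block3_pool",
    "block4_conv1", "block4_conv2", "block4_conv3", "block4_pool",
    "block5_conv1", "block5_conv2", "block5_conv3", "block5_pool",
]


def freeze_boundary(layer_names, first_trainable_block="block5"):
    """Return the index of the first layer of `first_trainable_block`.

    Layers before this index are the ones to freeze; this layer and
    everything after it (including a top model appended later) stay trainable.
    """
    prefix = first_trainable_block + "_"
    for i, name in enumerate(layer_names):
        if name.startswith(prefix):
            return i
    raise ValueError("no layer belongs to " + first_trainable_block)


if __name__ == "__main__":
    cut = freeze_boundary(VGG16_CONV_LAYERS)
    print(cut)                          # -> 14
    print(VGG16_CONV_LAYERS[cut - 1])   # -> block4_pool
```

In Keras the same helper could be applied to `[layer.name for layer in model.layers]`, setting `trainable = False` on every layer before the returned index.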

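In the earlier post "Keras: Training, Saving and Restoring an Image Classification Model on a Custom Dataset", we built the dataset used in this experiment and ran an initial round of training. Then, in "Keras: Using the Bottleneck Features of a Pretrained Network", we defined and trained the fully connected network used here and saved its weights to the file bottleneck_fc_model.h5.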

Fine-tuning process

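Following the structure of the VGG16 model in Keras's …/keras/applications/vgg16.py, build the convolutional part of VGG16 and load its weights (vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5). Then add the pretrained fully connected model on top. During training, freeze the parameters of the convolutional layers that come before the last convolutional block.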
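As a sanity check on the model described above, the output shape and parameter count of the VGG16 convolutional base can be worked out by hand. A minimal sketch, assuming a 150×150 RGB input in channels_first order, 3×3 convolutions with padding='same' (spatial size unchanged) and 2×2 max-pooling with stride 2; `VGG16_BLOCKS`, `conv_params` and `vgg16_base` are illustrative names:

```python
# (block name, number of 3x3 conv layers, filters per conv layer) for VGG16
VGG16_BLOCKS = [
    ("block1", 2, 64),
    ("block2", 2, 128),
    ("block3", 3, 256),
    ("block4", 3, 512),
    ("block5", 3, 512),
]


def conv_params(in_channels, filters, kernel=3):
    """Kernel weights (kernel*kernel*in*out) plus one bias per filter."""
    return (kernel * kernel * in_channels + 1) * filters


def vgg16_base(size=150, channels=3):
    """Return (output shape in channels_first order, total parameter count)."""
    total = 0
    for _name, n_convs, filters in VGG16_BLOCKS:
        for _ in range(n_convs):
            total += conv_params(channels, filters)
            channels = filters
        size = size // 2  # 2x2 pool, stride 2, 'valid' padding: halve, round down
    return (channels, size, size), total


if __name__ == "__main__":
    shape, params = vgg16_base()
    print(shape)                            # -> (512, 4, 4)
    print(params)                           # -> 14714688
    print(shape[0] * shape[1] * shape[2])   # -> 8192 features after Flatten
```

With all three Block 5 convolutions this gives 14,714,688 parameters. Each 512→512 convolution accounts for (3·3·512 + 1)·512 = 2,359,808 of them, which is exactly the difference between 14,714,688 and the 12,354,880 total in the summary output below, whose Block 5 lists only two convolutions.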

Example:

#!/usr/bin/python
# coding:utf8

from keras.models import Sequential
from keras import optimizers
from keras.preprocessing.image import ImageDataGenerator
from keras.layers import Flatten, Dense, Dropout, Conv2D, MaxPooling2D
from keras import backend as K
# channels_first: inputs are (3, 150, 150). Older Keras 2.x API; the
# equivalent call is K.set_image_data_format('channels_first')
K.set_image_dim_ordering('th')


# Build the convolutional part of VGG16
model = Sequential()
# Block 1
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv1', input_shape=(3, 150, 150)))
model.add(Conv2D(64, (3, 3), activation='relu', padding='same', name='block1_conv2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block1_pool'))
# Block 2
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv1'))
model.add(Conv2D(128, (3, 3), activation='relu', padding='same', name='block2_conv2'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block2_pool'))
# Block 3
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv1'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv2'))
model.add(Conv2D(256, (3, 3), activation='relu', padding='same', name='block3_conv3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block3_pool'))
# Block 4
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv1'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv2'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block4_conv3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block4_pool'))
# Block 5: three convolutions, as in keras/applications/vgg16.py
# (without block5_conv3 the model has 2,359,808 fewer parameters)
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv1'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv2'))
model.add(Conv2D(512, (3, 3), activation='relu', padding='same', name='block5_conv3'))
model.add(MaxPooling2D((2, 2), strides=(2, 2), name='block5_pool'))
# by_name=True matches the saved weights to layers through the names above
model.load_weights('vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5', by_name=True)
model.summary()

# Add the pretrained fully connected model on top of the initialized VGG network
top_model = Sequential()
top_model.add(Flatten(input_shape=model.output_shape[1:]))  # (512, 4, 4) in channels_first order
top_model.add(Dense(256, activation='relu'))
top_model.add(Dropout(0.5))
top_model.add(Dense(1, activation='sigmoid'))
# by_name=True only loads layers whose names match those saved in the file;
# layers with non-matching names are skipped without an error
top_model.load_weights('bottleneck_fc_model.h5', by_name=True)
model.add(top_model)

# Freeze every layer before the last convolutional block (block1_conv1
# through block4_pool, i.e. the first 14 layers) so their weights are not
# updated; Block 5 and the top model remain trainable
for layer in model.layers[:14]:
    layer.trainable = False

# Train with a low learning rate so the pretrained weights change only slightly
model.compile(loss='binary_crossentropy',
              optimizer=optimizers.SGD(lr=1e-4, momentum=0.9),
              metrics=['accuracy'])

train_datagen = ImageDataGenerator(rescale=1./255, shear_range=0.2, zoom_range=0.2, horizontal_flip=True)
test_datagen = ImageDataGenerator(rescale=1./255)
train_generator = train_datagen.flow_from_directory('train', target_size=(150, 150), batch_size=32, class_mode='binary')
validation_generator = test_datagen.flow_from_directory('validation', target_size=(150, 150), batch_size=32, class_mode='binary')
model.fit_generator(train_generator, steps_per_epoch=10, epochs=50,
                    validation_data=validation_generator, validation_steps=10)
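With 60 training images and batch_size=32 (the counts reported in the output below), one full pass over the data takes ceil(60 / 32) = 2 batches, so steps_per_epoch=10 makes each Keras "epoch" cover about five passes of augmented batches. A minimal sketch of the usual ceil(samples / batch_size) rule; the helpers `steps_for` and `passes_per_epoch` are hypothetical:

```python
import math


def steps_for(num_samples, batch_size):
    """Batches needed to see every sample once (the last batch may be smaller)."""
    return math.ceil(num_samples / batch_size)


def passes_per_epoch(steps_per_epoch, num_samples, batch_size):
    """How many full passes over the data a given steps_per_epoch amounts to."""
    return steps_per_epoch / steps_for(num_samples, batch_size)


if __name__ == "__main__":
    print(steps_for(60, 32))             # -> 2
    print(passes_per_epoch(10, 60, 32))  # -> 5.0
```

In Keras 2 the iterator returned by flow_from_directory exposes `samples` and `batch_size`, so `steps_for(train_generator.samples, train_generator.batch_size)` would give one pass per epoch.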

Output:

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
block1_conv1 (Conv2D)        (None, 64, 150, 150)      1792      
_________________________________________________________________
block1_conv2 (Conv2D)        (None, 64, 150, 150)      36928     
_________________________________________________________________
block1_pool (MaxPooling2D)   (None, 64, 75, 75)        0         
_________________________________________________________________
block2_conv1 (Conv2D)        (None, 128, 75, 75)       73856     
_________________________________________________________________
block2_conv2 (Conv2D)        (None, 128, 75, 75)       147584    
_________________________________________________________________
block2_pool (MaxPooling2D)   (None, 128, 37, 37)       0         
_________________________________________________________________
block3_conv1 (Conv2D)        (None, 256, 37, 37)       295168    
_________________________________________________________________
block3_conv2 (Conv2D)        (None, 256, 37, 37)       590080    
_________________________________________________________________
block3_conv3 (Conv2D)        (None, 256, 37, 37)       590080    
_________________________________________________________________
block3_pool (MaxPooling2D)   (None, 256, 18, 18)       0         
_________________________________________________________________
block4_conv1 (Conv2D)        (None, 512, 18, 18)       1180160   
_________________________________________________________________
block4_conv2 (Conv2D)        (None, 512, 18, 18)       2359808   
_________________________________________________________________
block4_conv3 (Conv2D)        (None, 512, 18, 18)       2359808   
_________________________________________________________________
block4_pool (MaxPooling2D)   (None, 512, 9, 9)         0         
_________________________________________________________________
block5_conv1 (Conv2D)        (None, 512, 9, 9)         2359808   
_________________________________________________________________
block5_conv2 (Conv2D)        (None, 512, 9, 9)         2359808   
_________________________________________________________________
block5_pool (MaxPooling2D)   (None, 512, 4, 4)         0         
=================================================================
Total params: 12,354,880
Trainable params: 12,354,880
Non-trainable params: 0
_________________________________________________________________
Found 60 images belonging to 2 classes.
Found 60 images belonging to 2 classes.
Epoch 1/50

 1/10 [==>...........................] - ETA: 6:57 - loss: 0.7880 - acc: 0.3929
 2/10 [=====>........................] - ETA: 6:23 - loss: 0.7920 - acc: 0.4152
 3/10 [========>.....................] - ETA: 5:25 - loss: 0.8292 - acc: 0.3839
 4/10 [===========>..................] - ETA: 4:47 - loss: 0.8184 - acc: 0.3895
 5/10 [==============>...............] - ETA: 3:59 - loss: 0.8159 - acc: 0.3929
 6/10 [=================>............] - ETA: 3:08 - loss: 0.8001 - acc: 0.4048
 7/10 [====================>.........] - ETA: 2:18 - loss: 0.8094 - acc: 0.4184
 8/10 [=======================>......] - ETA: 1:32 - loss: 0.8031 - acc: 0.4247
 9/10 [==========================>...] - ETA: 46s - loss: 0.8041 - acc: 0.4296 
10/10 [==============================] - 899s 90s/step - loss: 0.8125 - acc: 0.4260 - val_loss: 0.8145 - val_acc: 0.4000
Epoch 2/50

 1/10 [==>...........................] - ETA: 6:55 - loss: 0.8487 - acc: 0.4062
 2/10 [=====>........................] - ETA: 5:50 - loss: 0.8443 - acc: 0.4353
 3/10 [========>.....................] - ETA: 5:08 - loss: 0.8430 - acc: 0.4256
 4/10 [===========>..................] - ETA: 4:18 - loss: 0.8258 - acc: 0.4263
 5/10 [==============>...............] - ETA: 3:32 - loss: 0.8310 - acc: 0.4339
 6/10 [=================>............] - ETA: 2:53 - loss: 0.8266 - acc: 0.4397
 7/10 [====================>.........] - ETA: 2:11 - loss: 0.8270 - acc: 0.4305
 8/10 [=======================>......] - ETA: 1:26 - loss: 0.8220 - acc: 0.4347
  9/10 [==========================>...] - ETA: 43s - loss: 0.8311 - acc: 0.4340 

 ......
 ......
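To read the loss column above: binary_crossentropy averages −[y·ln p + (1 − y)·ln(1 − p)] over the samples, so a model that outputs p = 0.5 for every image scores ln 2 ≈ 0.693 whatever the labels are. A minimal pure-Python sketch of that formula (the function below is an illustration, not Keras's implementation; like Keras it clips p, here with an assumed eps of 1e-7):

```python
import math


def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean of -[y*ln(p) + (1-y)*ln(1-p)], with p clipped to [eps, 1-eps]."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)


if __name__ == "__main__":
    labels = [0, 1, 1, 0]
    print(round(binary_crossentropy(labels, [0.5] * 4), 4))  # -> 0.6931
```

Losses of around 0.8, as in the first two epochs shown, are therefore above the level of a constant 0.5 prediction.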