1. 程式人生 > >AI:拿來主義——預訓練網路(二)

AI:拿來主義——預訓練網路(二)

上一篇文章我們聊的是使用預訓練網路中的一種方法,特徵提取,今天我們討論另外一種方法,微調模型,這也是遷移學習的一種方法。

微調模型

為什麼需要微調模型?我們猜測和之前的實驗,我們有這樣的共識,資料量越少,網路的特徵節點越多,會越容易導致過擬合,這當然不是我們所希望的,但對於那些預先訓練好的模型,還有可能最終無法很好的完成所要做的工作,因此我們還需要對其更改,基於此原因,我們需要做的就是拿來一個訓練好的模型,更改其中更加抽象的層,即網路後面的層,然後再採用新的分類器,這樣可以比較好的解決上面所提出的過擬合問題了。

進行微調網路的步驟是:

  1. 在已經訓練好的網路(基網路)基礎上,新增自定義的層;

  2. 凍結基網路並訓練新新增的層;

  3. 凍結基網路的一部分層,另一部分可訓練;

  4. 聯合訓練解凍的這些層和新增的部分。

我們上一篇提到的方法就可以完成前兩個步驟,接下來我們看如何解決後兩個步驟。這裡我們還要更明確一下調整的層數如果過多會帶來什麼問題:隨著可變層數的增多,過擬合的風險會隨之加大。還要明確調整網路中識別畫素和線條的層不如調整識別耳朵的層更有效,因為不論是識別貓還是桌子識別線條的方法層更通用。

完成這項任務所需要寫的程式碼也是很簡單的,就是設定模型是可訓練的,然後遍歷網路的每一層,針對每一層分別設定是否是可訓練的,直到 layer_name 層,前面的層都是不可訓練的:

conv_base.trainable = True
set_trainable = False
for layer in conv_base.layers:
    if layer.name == 'layer_name':
        set_trainable = True
    if set_trainable:
        layer.trainable = True
    else:
        layer.trainable = False

這裡是關鍵部分程式碼,老規矩,最後將給出全部程式碼,我們先來看看結果:

需要注意一下這裡的資料,在開始的時候不穩定,迅速爬升,因此縱座標的資料沒有那麼好,但我們仔細看一下後期的資料,訓練精度和驗證精度都在百分之九十到百分之百,驗證精度一直有一些波動,是網路的一些噪聲引起的,我不想去強制讓它們那麼漂亮了,一是因為訓練時間會比較長,而是因為我覺得沒有特別大的必要,波動的最高點和最低點都在可接受的範圍內,應該把關注點放在更重要的問題上去。

基於本篇文章和上一篇文章,我們做個小結:

  1. 計算機視覺領域中,卷積神經網路的表現非常不錯,並且在資料集較小的情況下,表現讓人是非常優秀的。

  2. 資料增強是很好的避免過擬合的方法,過擬合產生的主要原因可能是資料量太少或者是引數過多。

  3. 特徵提取可以比較好的將現有的神經網路應用於小型資料集,還可以使用微調的方式進行優化。

我們看看程式碼吧,這裡還有一個建議,如果可能儘量使用 GPU 去做網路模型的訓練,CPU 在現階段處理這些問題會有點力不從心,耗時較長,讀者也可以考慮減少一些資料量加快速度,但要避免過擬合,請讀者心中記住此類問題,在遇到問題的時候是一個方向(當然,筆者是非常慘的,沒有好用的 GPU,因此等待資料畫圖截圖是非常痛苦的一件事):

#!/usr/bin/env python3
​
import os
import time
​
import matplotlib.pyplot as plt
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
​
​
def cat():
    base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
​
    train_datagen = ImageDataGenerator(
        rescale=1. / 255,
        rotation_range=40,
        width_shift_range=0.2,
        height_shift_range=0.2,
        shear_range=0.2,
        zoom_range=0.2,
        horizontal_flip=True,
        fill_mode='nearest')
​
    test_datagen = ImageDataGenerator(rescale=1. / 255)
​
    train_generator = train_datagen.flow_from_directory(
        train_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')
​
    validation_generator = test_datagen.flow_from_directory(
        validation_dir,
        target_size=(150, 150),
        batch_size=20,
        class_mode='binary')
​
    # 定義密集連線分類器
    conv_base = VGG16(weights='imagenet',
                      include_top=False,
                      input_shape=(150, 150, 3))
    conv_base.trainable = True
    set_trainable = False
    for layer in conv_base.layers:
        if layer.name == 'block5_conv1':
            set_trainable = True
        if set_trainable:
            layer.trainable = True
        else:
            layer.trainable = False
    model = models.Sequential()
    model.add(conv_base)
    model.add(layers.Flatten())
    model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1, activation='sigmoid'))
​
    conv_base.summary()
​
    # 對模型進行配置
    model.compile(loss='binary_crossentropy',
                  optimizer=optimizers.RMSprop(lr=1e-5),
                  metrics=['acc'])
​
    # 對模型進行訓練
    history = model.fit_generator(
        train_generator,
        steps_per_epoch=100,
        epochs=100,
        validation_data=validation_generator,
        validation_steps=50)
​
    # 畫圖
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(len(acc))
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.show()
    plt.figure()
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()
​
​
if __name__ == "__main__":
    time_start = time.time()
    cat()
    time_end = time.time()
    print('Time Used: ', time_end - time_start)

本文首發自公眾號:RAIS