1. 程式人生 > >AI:拿來主義——預訓練網路(一)

AI:拿來主義——預訓練網路(一)

我們已經訓練過幾個神經網路了,識別手寫數字,房價預測或者是區分貓和狗,那隨之而來就有一個問題,這些訓練出的網路怎麼用,每個問題我都需要重新去訓練網路嗎?因為程式設計師都不太喜歡做重複的事情,因此答案肯定是已經有輪子了。

我們先來介紹一個數據集,ImageNet。這就不得不提一個大名鼎鼎的華裔 AI 科學家李飛飛。

2005 年左右,李飛飛結束了他的博士生涯,開始了他的學術研究不就她就意識到了一個問題,在此之前,人們都儘可能優化演算法,認為無論資料如何,只要演算法夠好,就能做出更好的決策,李飛飛意識到了這個問題的侷限性,恰巧她還是一個行動派,她要做出一個無比龐大的資料集,儘可能描述世界上一切物體的資料集,下載圖片,給沒一張圖片做標註,簡單而無聊,當然後來這項工作放到了亞馬遜的眾包平臺上,全世界無數的人蔘與了這個偉大的專案,到此刻為止,已經有 14,197,122 張圖片(一千四百萬張),21841 個分類。在這個發展的過程中,人們也發現了這個資料集帶來的成功遠比預想的要多,甚至現在被認為最有前景的深度卷積神經網路的提出也與 ImageNet 不無關係。我忘記了誰這麼說過:“就單單這一個資料集,就可以讓李飛飛資料科學這個領域擁有一席之地”。暫且不說這麼說是否準確,但這個資料集仍然在創造新的突破。(我曾經在臺下聽過李飛飛一次演講,現在想想還覺得甚是激動,她真的充滿熱情)。

基於這個資料集,我們是不是可以訓練出一些網路,一般情況下,大家就不用耗時再去訓練網路了呢?答案是肯定的,並且在 Keras 就有個一些這樣的模型,還是內建的,Keras 就是這麼懂你,那就不用客氣了,我們拿來用就好了,謝謝啦!

特徵提取

我們之前用到的卷積神經網路都是分成了兩部分,第一部分是由池化層和卷積層組成的卷積積,第二部分是由分類器,特徵提取的含義就是第一部分不變,改變第二部分。

為什麼可以這麼做?我們之前解釋過神經網路的執行原理,跟人腦的認識過程非常類似,還記得嗎?我們還是看一看原來的圖吧。

我們可以看出來,網路識別影象是有層次結構的,比如一開始的網路層是用來識別影象或者拼裝線條的,這是通用且類似的,因此我們可以複用。而後面的分類器往往是根據具體的問題所決定的,比如識別貓或狗的眼睛就與識別桌子腿是不一樣的,因此有越靠前越具有通用性的特點。Keras 中很多的內建模型都可以直接下載,如果你沒有下載在使用的時候會自動下載:

https://github.com/fchollet/deep-learning-models/releases

我們舉一個例子,用 VGG16 去識別貓或狗,這次的解釋都比較簡單且都是以前說明過的,因此放在程式碼註釋中:

#!/usr/bin/env python3
​
import os
import time
​
import matplotlib.pyplot as plt
import numpy as np
from keras import layers
from keras import models
from keras import optimizers
from keras.applications import VGG16
from keras.preprocessing.image import ImageDataGenerator
​
​
def extract_features(directory, sample_count):
    # 圖片轉換區間
    datagen = ImageDataGenerator(rescale=1. / 255)
    batch_size = 20
    conv_base = VGG16(weights='imagenet',
                      include_top=False,
                      input_shape=(150, 150, 3))
​
    conv_base.summary()
​
    features = np.zeros(shape=(sample_count, 4, 4, 512))
    labels = np.zeros(shape=(sample_count))
    # 讀出圖片,處理成神經網路需要的資料格式,上一篇文章中有介紹
    generator = datagen.flow_from_directory(
        directory,
        target_size=(150, 150),
        batch_size=batch_size,
        class_mode='binary')
    i = 0
    for inputs_batch, labels_batch in generator:
        print(i, '/', len(generator))
        # 提取特徵
        features_batch = conv_base.predict(inputs_batch)
        features[i * batch_size: (i + 1) * batch_size] = features_batch
        labels[i * batch_size: (i + 1) * batch_size] = labels_batch
        i += 1
        if i * batch_size >= sample_count:
            break
​
    # 特徵和標籤
    return features, labels
​
​
def cat():
    base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small'
    train_dir = os.path.join(base_dir, 'train')
    validation_dir = os.path.join(base_dir, 'validation')
​
    # 提取出的特徵
    train_features, train_labels = extract_features(train_dir, 2000)
    validation_features, validation_labels = extract_features(validation_dir, 1000)
​
    # 對特徵進行變形展平
    train_features = np.reshape(train_features, (2000, 4 * 4 * 512))
    validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512))
​
    # 定義密集連線分類器
    model = models.Sequential()
    model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512))
    model.add(layers.Dropout(0.5))
    model.add(layers.Dense(1, activation='sigmoid'))
​
    # 對模型進行配置
    model.compile(optimizer=optimizers.RMSprop(lr=2e-5),
                  loss='binary_crossentropy',
                  metrics=['acc'])
​
    # 對模型進行訓練
    history = model.fit(train_features, train_labels,
                        epochs=30,
                        batch_size=20,
                        validation_data=(validation_features, validation_labels))
​
    # 畫圖
    acc = history.history['acc']
    val_acc = history.history['val_acc']
    loss = history.history['loss']
    val_loss = history.history['val_loss']
    epochs = range(1, len(acc) + 1)
    plt.plot(epochs, acc, 'bo', label='Training acc')
    plt.plot(epochs, val_acc, 'b', label='Validation acc')
    plt.title('Training and validation accuracy')
    plt.legend()
    plt.show()
    plt.figure()
    plt.plot(epochs, loss, 'bo', label='Training loss')
    plt.plot(epochs, val_loss, 'b', label='Validation loss')
    plt.title('Training and validation loss')
    plt.legend()
    plt.show()
​
​
if __name__ == "__main__":
    time_start = time.time()
    cat()
    time_end = time.time()
    print('Time Used: ', time_end - time_start)​

有點巧合的是這裡居然看不到太多的過擬合的痕跡,其實也是有可能會有過擬合的隱患的,那樣就需要進行資料增強,與以前是一樣的,只不過這裡的區別就是用到了內建模型,模型的引數需要凍結,我們是不希望對已經訓練好的模型進行更改的,具體關鍵程式碼寫法如下:

conv_base.trainable = False
​
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))

以上就是模型複用的一種方法,我們對模型都是原封不動的拿來用,我們下一篇文章將介紹另外一種方法,對模型進行微調。

首發自公眾號:RAIS