AI:拿來主義——預訓練網路(一)
我們已經訓練過幾個神經網路了,識別手寫數字,房價預測或者是區分貓和狗,那隨之而來就有一個問題,這些訓練出的網路怎麼用,每個問題我都需要重新去訓練網路嗎?因為程式設計師都不太喜歡做重複的事情,因此答案肯定是已經有輪子了。
我們先來介紹一個數據集,ImageNet。這就不得不提一個大名鼎鼎的華裔 AI 科學家李飛飛。
2005 年左右,李飛飛結束了他的博士生涯,開始了他的學術研究不就她就意識到了一個問題,在此之前,人們都儘可能優化演算法,認為無論資料如何,只要演算法夠好,就能做出更好的決策,李飛飛意識到了這個問題的侷限性,恰巧她還是一個行動派,她要做出一個無比龐大的資料集,儘可能描述世界上一切物體的資料集,下載圖片,給沒一張圖片做標註,簡單而無聊,當然後來這項工作放到了亞馬遜的眾包平臺上,全世界無數的人蔘與了這個偉大的專案,到此刻為止,已經有 14,197,122 張圖片(一千四百萬張),21841 個分類。在這個發展的過程中,人們也發現了這個資料集帶來的成功遠比預想的要多,甚至現在被認為最有前景的深度卷積神經網路的提出也與 ImageNet 不無關係。我忘記了誰這麼說過:“就單單這一個資料集,就可以讓李飛飛資料科學這個領域擁有一席之地”。暫且不說這麼說是否準確,但這個資料集仍然在創造新的突破。(我曾經在臺下聽過李飛飛一次演講,現在想想還覺得甚是激動,她真的充滿熱情)。
基於這個資料集,我們是不是可以訓練出一些網路,一般情況下,大家就不用耗時再去訓練網路了呢?答案是肯定的,並且在 Keras 就有個一些這樣的模型,還是內建的,Keras 就是這麼懂你,那就不用客氣了,我們拿來用就好了,謝謝啦!
特徵提取
我們之前用到的卷積神經網路都是分成了兩部分,第一部分是由池化層和卷積層組成的卷積積,第二部分是由分類器,特徵提取的含義就是第一部分不變,改變第二部分。
為什麼可以這麼做?我們之前解釋過神經網路的執行原理,跟人腦的認識過程非常類似,還記得嗎?我們還是看一看原來的圖吧。
我們可以看出來,網路識別影象是有層次結構的,比如一開始的網路層是用來識別影象或者拼裝線條的,這是通用且類似的,因此我們可以複用。而後面的分類器往往是根據具體的問題所決定的,比如識別貓或狗的眼睛就與識別桌子腿是不一樣的,因此有越靠前越具有通用性的特點。Keras 中很多的內建模型都可以直接下載,如果你沒有下載在使用的時候會自動下載:
https://github.com/fchollet/deep-learning-models/releases
我們舉一個例子,用 VGG16 去識別貓或狗,這次的解釋都比較簡單且都是以前說明過的,因此放在程式碼註釋中:
#!/usr/bin/env python3 import os import time import matplotlib.pyplot as plt import numpy as np from keras import layers from keras import models from keras import optimizers from keras.applications import VGG16 from keras.preprocessing.image import ImageDataGenerator def extract_features(directory, sample_count): # 圖片轉換區間 datagen = ImageDataGenerator(rescale=1. / 255) batch_size = 20 conv_base = VGG16(weights='imagenet', include_top=False, input_shape=(150, 150, 3)) conv_base.summary() features = np.zeros(shape=(sample_count, 4, 4, 512)) labels = np.zeros(shape=(sample_count)) # 讀出圖片,處理成神經網路需要的資料格式,上一篇文章中有介紹 generator = datagen.flow_from_directory( directory, target_size=(150, 150), batch_size=batch_size, class_mode='binary') i = 0 for inputs_batch, labels_batch in generator: print(i, '/', len(generator)) # 提取特徵 features_batch = conv_base.predict(inputs_batch) features[i * batch_size: (i + 1) * batch_size] = features_batch labels[i * batch_size: (i + 1) * batch_size] = labels_batch i += 1 if i * batch_size >= sample_count: break # 特徵和標籤 return features, labels def cat(): base_dir = '/Users/renyuzhuo/Desktop/cat/dogs-vs-cats-small' train_dir = os.path.join(base_dir, 'train') validation_dir = os.path.join(base_dir, 'validation') # 提取出的特徵 train_features, train_labels = extract_features(train_dir, 2000) validation_features, validation_labels = extract_features(validation_dir, 1000) # 對特徵進行變形展平 train_features = np.reshape(train_features, (2000, 4 * 4 * 512)) validation_features = np.reshape(validation_features, (1000, 4 * 4 * 512)) # 定義密集連線分類器 model = models.Sequential() model.add(layers.Dense(256, activation='relu', input_dim=4 * 4 * 512)) model.add(layers.Dropout(0.5)) model.add(layers.Dense(1, activation='sigmoid')) # 對模型進行配置 model.compile(optimizer=optimizers.RMSprop(lr=2e-5), loss='binary_crossentropy', metrics=['acc']) # 對模型進行訓練 history = model.fit(train_features, train_labels, epochs=30, batch_size=20, validation_data=(validation_features, validation_labels)) # 畫圖 acc = history.history['acc'] val_acc = history.history['val_acc'] loss = history.history['loss'] val_loss = history.history['val_loss'] epochs = range(1, len(acc) + 1) plt.plot(epochs, acc, 'bo', label='Training acc') plt.plot(epochs, val_acc, 'b', label='Validation acc') plt.title('Training and validation accuracy') plt.legend() plt.show() plt.figure() plt.plot(epochs, loss, 'bo', label='Training loss') plt.plot(epochs, val_loss, 'b', label='Validation loss') plt.title('Training and validation loss') plt.legend() plt.show() if __name__ == "__main__": time_start = time.time() cat() time_end = time.time() print('Time Used: ', time_end - time_start)
有點巧合的是這裡居然看不到太多的過擬合的痕跡,其實也是有可能會有過擬合的隱患的,那樣就需要進行資料增強,與以前是一樣的,只不過這裡的區別就是用到了內建模型,模型的引數需要凍結,我們是不希望對已經訓練好的模型進行更改的,具體關鍵程式碼寫法如下:
conv_base.trainable = False
model = models.Sequential()
model.add(conv_base)
model.add(layers.Flatten())
model.add(layers.Dense(256, activation='relu'))
model.add(layers.Dense(1, activation='sigmoid'))
以上就是模型複用的一種方法,我們對模型都是原封不動的拿來用,我們下一篇文章將介紹另外一種方法,對模型進行微調。
首發自公眾號:RAIS