快速入門深度學習(2)遷移學習
咱們繼續入門課程系列,這次是關於遷移學習(Transfer Learning)的故事。
這次咱們要“學習”一把了,針對特定的任務構造自己的分類器了。這次咱們仍然使用AlexNet的網路結構(誰讓它經典呢),訓練這個網路讓它為咱們服務。
在正式Coding之前,首先了解下什麼是遷移學習。所謂的遷移學習就是指在深度學習中,把一個學習好的深度網路,稍加改造變成自己特有網路的意思,至於這樣做的道理,咱們這裡先不深入探討,只要先記住遷移學習有個很大的好處,就是網路收斂速度快。
實驗準備
Matlab2017b或者更新的版本,AlexNet。
資料準備:為了實驗的一致性,使用Matlab計算機視覺工具箱自帶的資料。
開始程式設計
載入資料
unzip('MerchData.zip');
imds = imageDatastore('MerchData','IncludeSubfolders',true,'LabelSource','foldernames');
[imdsTrain,idmsValidation] =splitEachLabel(imds,0.7,'randomized');
unzip函式的意思是解壓壓縮檔案。執行這一句之後可以看到在當前目錄下多了一個資料夾:
這個資料夾裡面就是本次實驗所使用的資料。為了更方便地組織該資料,我們使用imageDatastore函式來構造一個數據結構,用以管理資料。執行上面一句之後得到來一個imageDatastore資料結構,我們進入當前的工作空間對其進行觀察。
可以看到待使用的資料,被一個數據結構進行了組織,並且使用資料夾的名稱作為了類標籤。我們隨機選擇16個影象用 的方式進行顯示。
numImages= numel(imds.Labels);%統計總數
idx =randperm(numImages,16); %隨機選擇
figure
for i = 1:16
subplot(4,4,i)
I = readimage(imds,idx(i));
imshow(I)
end
可以看到
我們接下來把影象分為測試集(30%)和訓練集(70%):
[imdsTrain,idmsValidation]= splitEachLabel(imds,0.7,'randomized'
資料準備完畢了。
載入AlexNet網路
由於我們這一章講的是遷移學習,所以接下來需要載入已經訓練好的alexnet網路。關於如何載入請參看前一章。
net =alexnet;
修改網路
由於咱們這次只需要識別5個類,所以需要對AlexNet網路進行修改以適應當前的問題。我們這次主要對其進行如下修改:修改全連線層的輸出數量,從原來的1000變為5,其餘保持不變。首先提取出前面的層數,然後使用fullyConnectedLayer構造全連線層,最後完成整個網路的構建。
layersTransfer= net.Layers(1:end-3);
layers =[
layersTransfer
fullyConnectedLayer(5,'WeightLearnRateFactor',20,'BiasLearnRateFactor',20)
softmaxLayer
classificationLayer];
最後的結果layer就是我們需要的網路結構,此時網路還未經訓練。
訓練網路
訓練網路在Matlab中是一件非常簡單的事情,我們只需要配置好訓練引數就好了:
options = trainingOptions('sgdm',...
'MiniBatchSize',10, ...
'MaxEpochs',6, ...
'InitialLearnRate',1e-4, ...
'ValidationData',idmsValidation, ...
'ValidationFrequency',3, ...
'ValidationPatience',Inf, ...
'Verbose',false, ...
'Plots','training-progress');
關於訓練的引數,咱們以後再詳細介紹,這裡需要了解的一點就是,由於神經網路引數眾多,而且是一個典型的非凸優化問題,所以,訓練的引數選擇相當重要。
netTransfer = trainNetwork(imdsTrain,layers,options);
執行完上面一句就可以得到netTransfer作為遷移網路。
驗證網路
我們使用驗證集去測試神經網路的有效性:
YPred = classify(netTransfer,idmsValidation);
accuracy = mean(YPred == idmsValidation.Labels)
結果表明我們的訓練出來的神經網路具有良好的泛化性。
總結
從上面的程式設計過程中,可以發現Matlab神經網路工具箱已經幫助我們做好了很多工作,我們只需要去設計網路即可,然後訓練即可,把廣大程式設計師從無邊無際的codeing中解放出來。