1. 程式人生 > >知識蒸餾(Knowledge Distillation)的Pytorch實現以及分析

知識蒸餾(Knowledge Distillation)的Pytorch實現以及分析

       知識蒸餾(Knowledge Distillation)的概念由Hinton大神於2015年在論文《Distilling the Knowledge in a Neural Network》中提出,論文見:https://arxiv.org/abs/1503.02531。此方法的主要思想為:通過結構複雜、計算量大但是效能優秀的教師神經網路,對結構相對簡單、計算量較小的學生神經網路進行指導,以提升學生神經網路的效能。論文中提出了“暗知識”這一概念,即:比如我們在識別一張貓貓的圖片的時候,一個性能良好的神經網路經過softmax變換後的輸出,在一般該向量中代表貓貓的位置會得到一個非常高的值,比如,0.9,而代表其它分類的值在傳統的研究中就不那麼受重視了。Hinton大神認為,其它位置得到的值能夠為我們提供一些額外的資訊,比如,在貓得到0.9的同時,識別為獅子的值可能因為相似的緣故給到了0.09,而識別為汽車的值則可能只有0.0001。在我的理解中,這種目標間的相似性,就是“暗知識”的本質。為了要放大這種“暗知識”所包含的資訊,Hinton在傳統的softmax函式中加入溫度引數T,變為下式所示:

                                                                        

       那麼,知識蒸餾的步驟分別為:

一、採用傳統方式訓練一個教師網路。

二、建立學生網路模型,模型的輸出採用傳統的softmax函式,擬合目標為one-hot形式的訓練集輸出,它們之間的距離記為loss1。

三、將訓練完成的教師網路的softmax分類器加入溫度引數,作為具有相同溫度引數softmax分類器的學生網路的擬合目標,他們之間的距離記為loss2。

四、引入引數alpha,將loss1×(1-alpha)+loss2×alpha作為網路訓練時使用的loss,訓練網路。

       重點就在於將暗知識放大之後,讓學生網路的暗知識去擬合教師網路的暗知識,同時由於教師網路會帶有一定的bias,表現為教師網路在訓練完成後,對訓練集識別的正確率會高於測試集,所以加上loss1來減緩這種趨勢,實際應用的時候,可以考慮將alpha首先設定的接近1,比如0.95啥的,來快速擬合教師網路,再逐步調低alpha的值,來確保網路的分類正確率,不過這只是理論上可行的,我也沒試驗就是了……

       那我們就開搞啦,首先是搭建教師網路,我這裡選擇的是resnet18,並且由於電腦訓練速度的原因(渣機無力……)將網路中所有卷積核的數目減少了一半,訓練集採用Cifar10,訓練時對影象進行了padding之後隨機裁剪以及隨機水平翻轉來加入噪聲。優化器採用帶動量項的SGD(lr=0.1, momentum=0.9, weight_decay=5e-4),訓練200個epoch,其中在第100以及第150個epoch時將學習率除10,詳細的程式碼見文章末尾的github地址好啦。訓練完成後,網路對測試集的識別結果如下所示:

Accuracy of the network on the 10000 test images: 90.820000 %
Accuracy of plane : 97.727273 %
Accuracy of   car : 100.000000 %
Accuracy of  bird : 86.842105 %
Accuracy of   cat : 83.720930 %
Accuracy of  deer : 91.836735 %
Accuracy of   dog : 96.875000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 90.625000 %
Accuracy of  ship : 93.750000 %
Accuracy of truck : 82.758621 %

       這結果當然並不算特別好,所以作為學生的網路,得選個效果比較差的,這樣才能體現出教師的價值對吧(笑)。這裡我們就簡單的架一個三層卷積神經網路作為學生網路好啦,網路具體結構見github。還是使用cifar10經過相同的影象變換過程後,採用adam(lr=0.001)作為優化器對網路訓練100個epoch,在完全相同的條件下訓練四次,測試集識別結果分別如下,我們可以看到,這幾次的訓練結果平均一下大概差不多卡在80%的點上,畢竟這玩意有運氣成分在。

第一次訓練結果:
Accuracy of the network on the 10000 test images: 78.590000 %
Accuracy of plane : 81.818182 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 71.052632 %
Accuracy of   cat : 67.441860 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 81.250000 %
Accuracy of  frog : 94.339623 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 82.758621 %
第二次訓練結果:
Accuracy of the network on the 10000 test images: 79.990000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 69.767442 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 65.625000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 81.250000 %
Accuracy of truck : 93.103448 %
第三次訓練結果:
Accuracy of the network on the 10000 test images: 81.160000 %
Accuracy of plane : 79.545455 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 72.093023 %
Accuracy of  deer : 79.591837 %
Accuracy of   dog : 87.500000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 84.375000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 89.655172 %
第四次訓練結果:
Accuracy of the network on the 10000 test images: 79.890000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 87.500000 %
Accuracy of  bird : 76.315789 %
Accuracy of   cat : 65.116279 %
Accuracy of  deer : 71.428571 %
Accuracy of   dog : 71.875000 %
Accuracy of  frog : 92.452830 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 91.666667 %
Accuracy of truck : 72.413793 %

       接下來,因為之前看到網上有人說,教師網路本身在訓練的時候,是有采用加噪資料進行訓練的,所以它的輸出的暗知識在理論上可能會包含有噪聲項的資訊,我們就先在不對資料集進行變換的情況下進行訓練。這裡我們選取alpha=0.95,T選取2和10分別訓練兩次,結果如下。我們可以看到,其訓練的結果比之前的方法是要差的,這可能是因為學生網路還是直接過擬合了教師網路的輸出,所以導致測試集正確率較低。

T=2第一次訓練結果:
Accuracy of the network on the 10000 test images: 78.440000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 65.789474 %
Accuracy of   cat : 65.116279 %
Accuracy of  deer : 73.469388 %
Accuracy of   dog : 71.875000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 81.250000 %
Accuracy of  ship : 81.250000 %
Accuracy of truck : 89.655172 %
T=2第二次訓練結果:
Accuracy of the network on the 10000 test images: 75.830000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 55.813953 %
Accuracy of  deer : 71.428571 %
Accuracy of   dog : 59.375000 %
Accuracy of  frog : 73.584906 %
Accuracy of horse : 81.250000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 79.310345 %
T=10第一次訓練結果:
Accuracy of the network on the 10000 test images: 78.300000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 60.465116 %
Accuracy of  deer : 73.469388 %
Accuracy of   dog : 59.375000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 71.875000 %
Accuracy of  ship : 85.416667 %
Accuracy of truck : 93.103448 %
T=10第二次訓練結果:
Accuracy of the network on the 10000 test images: 76.130000 %
Accuracy of plane : 95.454545 %
Accuracy of   car : 87.500000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 60.465116 %
Accuracy of  deer : 63.265306 %
Accuracy of   dog : 62.500000 %
Accuracy of  frog : 84.905660 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 89.583333 %
Accuracy of truck : 72.413793 %

       最後是對圖片進行了相應的變換加入噪聲後,對學生網路進行訓練,結果如下:

T=2第一次訓練結果:
Accuracy of the network on the 10000 test images: 82.460000 %
Accuracy of plane : 95.454545 %
Accuracy of   car : 90.625000 %
Accuracy of  bird : 78.947368 %
Accuracy of   cat : 69.767442 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 65.625000 %
Accuracy of  frog : 88.679245 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 81.250000 %
Accuracy of truck : 93.103448 %
T=2第二次訓練結果:
Accuracy of the network on the 10000 test images: 80.760000 %
Accuracy of plane : 86.363636 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 71.052632 %
Accuracy of   cat : 81.395349 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 75.000000 %
Accuracy of  frog : 81.132075 %
Accuracy of horse : 78.125000 %
Accuracy of  ship : 79.166667 %
Accuracy of truck : 89.655172 %
T=10第一次訓練結果:
Accuracy of the network on the 10000 test images: 81.780000 %
Accuracy of plane : 90.909091 %
Accuracy of   car : 93.750000 %
Accuracy of  bird : 63.157895 %
Accuracy of   cat : 81.395349 %
Accuracy of  deer : 79.591837 %
Accuracy of   dog : 87.500000 %
Accuracy of  frog : 83.018868 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 95.833333 %
Accuracy of truck : 93.103448 %
T=10第二次訓練結果:
Accuracy of the network on the 10000 test images: 81.470000 %
Accuracy of plane : 88.636364 %
Accuracy of   car : 100.000000 %
Accuracy of  bird : 60.526316 %
Accuracy of   cat : 60.465116 %
Accuracy of  deer : 81.632653 %
Accuracy of   dog : 75.000000 %
Accuracy of  frog : 75.471698 %
Accuracy of horse : 87.500000 %
Accuracy of  ship : 93.750000 %
Accuracy of truck : 93.103448 %

       雖然測試集的正確率具有一定程度的不確定性,我們還是可以看出,測試集正確率相比原始的訓練方法平均提升了約1.5%。這也可以大致說明這種方法的有效性。當然,這種訓練方式目前也產生了很多的變體,比如再生網路等等、

       最後是相關程式與訓練完成的網路引數檔案的github地址:https://github.com/PolarisShi/distillation