1. 程式人生 > 其它 >知識蒸餾是什麼?一份入門隨筆

知識蒸餾是什麼?一份入門隨筆

0. 寫在前面

有人說過:“神經網路用剩的logits不要扔,沾上雞蛋液,裹上面包糠...” 這兩天對知識蒸餾(Knowledge Distillation)萌生了一點興趣,正好寫一篇文章分享一下。這篇文章姑且算是一篇小科普。

前排小廣告:如果覺得文章對你有幫助,歡迎點贊、關注我的小專欄ML4NLP

1. 從模型壓縮開始

各種模型演算法,最終目的都是要為某個應用服務。在買賣中,我們需要控制收入和支出。類似地,在工業應用中,除了要求模型要有好的預測(收入)以外,往往還希望它的「支出」要足夠小。具體來說,我們一般希望部署到應用中的模型使用較少的計算資源(儲存空間、計算單元等),產生較低的時延。

在深度學習的背景下,為了達到更好的預測,常常會有兩種方案:1. 使用過引數化的深度神經網路,這類網路學習能力非常強,因此往往加上一定的正則化策略(如dropout);2. 整合模型(ensemble),將許多弱的模型整合起來,往往可以實現較好的預測。這兩種方案無疑都有較大的「支出」,需要的計算量和計算資源很大,對部署非常不利。這也就是模型壓縮的動機:我們希望有一個規模較小的模型,能達到和大模型一樣或相當的結果。當然,從頭訓練一個小模型,從經驗上看是很難達到上述效果的,也許我們能先訓練一個大而強的模型,然後將其包含的知識轉移給小的模型呢?如何做到呢?

* 下文統一將要訓練的小模型稱為新模型,將以及訓練的大模型稱為原模型。

Rich Caruana等人在[1]中指出,可以讓新模型近似(approximate)原模型(模型即函式)。注意到,在機器學習中,我們常常假定輸入到輸出有一個潛在的函式關係,這個函式是未知的:從頭學習一個新模型就是從有限的資料中近似一個未知的函式。如果讓新模型近似原模型,因為原模型的函式是已知的,我們可以使用很多非訓練集內的偽資料來訓練新模型,這顯然要更可行。

這樣,原來我們需要讓新模型的softmax分佈與真實標籤匹配,現在只需要讓新模型與原模型在給定輸入下的softmax分佈匹配了。直觀來看,後者比前者具有這樣一個優勢:經過訓練後的原模型,其softmax分佈包含有一定的知識——真實標籤只能告訴我們,某個影象樣本是一輛寶馬,不是一輛垃圾車,也不是一顆蘿蔔;而經過訓練的softmax可能會告訴我們,它最可能是一輛寶馬,不大可能是一輛垃圾車,但絕不可能是一顆蘿蔔[2]。

2. 為什麼叫「蒸餾」?

接續前面的討論,我們的目標是讓新模型與原模型的softmax輸出的分佈充分接近。直接這樣做是有問題的:在一般的softmax函式中,自然指數  先拉大logits之間的差距,然後作歸一化,最終得到的分佈是一個arg max的近似 ,其輸出是一個接近one-hot的向量,其中一個值很大,其他的都很小。這種情況下,前面說到的「可能是垃圾車,但絕不是蘿蔔」這種知識的體現是非常有限的。相較類似one-hot這樣的硬性輸出,我們更希望輸出更「軟」一些。

一種方法是直接比較logits來避免這個問題。具體地,對於每一條資料,記原模型產生的某個logits是  ,新模型產生的logits是  ,我們需要最小化

文獻[2]提出了更通用的一種做法。考慮一個廣義的softmax函式

其中  是溫度,這是從統計力學中的玻爾茲曼分佈中借用的概念。容易證明,當溫度  趨向於0時,softmax輸出將收斂為一個one-hot向量(證明可以參考我之前的文章:淺談Softmax函式,將  替換為  即可);溫度  趨向於無窮時,softmax的輸出則更「軟」。因此,在訓練新模型的時候,可以使用較高的  使得softmax產生的分佈足夠軟,這時讓新模型(同樣溫度下)的softmax輸出近似原模型;在訓練結束以後再使用正常的溫度  來預測。具體地,在訓練時我們需要最小化兩個分佈的交叉熵(Cross-entropy),記新模型利用公式  產生的分佈是  ,原模型產生的分佈是  ,則我們需要最小化

在化學中,蒸餾是一個有效的分離沸點不同的組分的方法,大致步驟是先升溫使低沸點的組分汽化,然後降溫冷凝,達到分離出目標物質的目的。在前面提到的這個過程中,我們先讓溫度  升高,然後在測試階段恢復「低溫」,從而將原模型中的知識提取出來,因此將其稱為是蒸餾,實在是妙。

當然,如果轉移時使用的是有標籤的資料,那麼也可以將標籤與新模型softmax分佈的交叉熵加入到損失函式中去。這裡需要將式  乘上一個  ,這是為了讓損失函式的兩項的梯度大致在一個數量級上(參考公式  ),實驗表明這將大大改善新模型的表現(考慮到加入了更多的監督訊號)。

3. 與直接優化logits差異相比

由公式  ,對於交叉熵損失來說,其對於新模型的某個logit  的梯度是

由於  與  是等價無窮小( 時),易知,當  充分大時,有

假設所有logits對每個樣本都是零均值化的,即  ,則有

所以,如果:1.  非常大,2. logits對所有樣本都是零均值化的,則知識蒸餾和最小化logits的平方差(公式  )是等價的(因為梯度大致是同一個形式)。實驗表明,溫度  不能取太大,而應該使用某個適中的值,這表明忽略極負的logits對新模型的表現很有幫助(較低的溫度產生的分佈比較「硬」,傾向於忽略logits中極小的負值)。

4. 實驗與結論

Hinton等人做了三組實驗,其中兩組都驗證了知識蒸餾方法的有效性。在MNIST資料集上的實驗表明,即便有部分類別的樣本缺失,新模型也可以表現得很不錯,只需要修改相應的偏置項,就可以與原模型表現相當。在語音任務的實驗也表明,蒸餾得到的模型比從頭訓練的模型捕捉了更多資料集中的有效資訊,表現僅比整合模型低了0.3個百分點。總體來說知識蒸餾是一個簡單而有效的模型壓縮/訓練方法。這大體上是因為原模型的softmax提供了比one-hot標籤更多的監督訊號[3]。

知識蒸餾在後續也有很多延伸工作。在NLP方面比較有名的有Yoon Kim等人的Sequence-Level Knowledge Distillation 等。總的來說,對一些比較臃腫、不便部署的模型,可以將其「知識」轉移到小的模型上。比如,在機器翻譯中,一般的模型需要有較大的容量(capacity)才可能獲得較好的結果;現在非常流行的BERT及其變種,規模都非常大;更不用提,一些情形下我們需要將這些本身就很大的深度模型整合為一個ensemble,這時候,可以用知識蒸餾壓縮出一個較小的、「便宜」的模型。

另外,在多工的情境下,使用一般的策略訓練一個多工模型,可能達不到比單任務更好的效果,文獻[3]探索了使用知識蒸餾,利用單任務的模型來指導訓練多工模型的方法,很值得參考。

補充

鑑於評論區有知友對公式  有疑問,簡單補充一下這裡梯度的推導(其實就是交叉熵損失對softmax輸入的梯度,LOL)。

* 這部分有一點繁瑣,能接受公式  的讀者可以跳過。

由鏈式法則,有

注意到  是原模型產生的softmax輸出,與  無關。

後一項  比較容易得到,因為  ,所以

則  是一個  維向量

前一項  是一個  的方陣,分類討論可以得到。參考公式  ,記  ,由除法的求導法則,輸出元素  對輸入  的偏導是

注意上面右側加方框部分,可以進一步展開

這樣,代入公式  ,並且將括號展開,可以得到

左側方框內偏導可以分類討論得到

帶入式  ,得到

所以  形式如下

代入式  ,可得

 

所以有公式  ,  。

 

參考

[1] Caruana et al., Model Compression, 2006

[2] Hinton et al., Distilling the Knowledge in a Neural Network, 2015

[3] Kevin Clark et al., BAM! Born-Again Multi-Task Networks for Natural Language Understanding