1. 程式人生 > >一分鐘帶你認識深度學習中的知識蒸餾

一分鐘帶你認識深度學習中的知識蒸餾

摘要:知識蒸餾(knowledge distillation)是模型壓縮的一種常用的方法

一、知識蒸餾入門

1.1 概念介紹

知識蒸餾(knowledge distillation)是模型壓縮的一種常用的方法,不同於模型壓縮中的剪枝和量化,知識蒸餾是通過構建一個輕量化的小模型,利用效能更好的大模型的監督資訊,來訓練這個小模型,以期達到更好的效能和精度。最早是由Hinton在2015年首次提出並應用在分類任務上面,這個大模型我們稱之為teacher(教師模型),小模型我們稱之為Student(學生模型)。來自Teacher模型輸出的監督資訊稱之為knowledge(知識),而student學習遷移來自teacher的監督資訊的過程稱之為Distillation(蒸餾)。

1.2 知識蒸餾的種類

圖1 知識蒸餾的種類

1、 離線蒸餾

離線蒸餾方式即為傳統的知識蒸餾,如上圖(a)。使用者需要在已知資料集上面提前訓練好一個teacher模型,然後在對student模型進行訓練的時候,利用所獲取的teacher模型進行監督訓練來達到蒸餾的目的,而且這個teacher的訓練精度要比student模型精度要高,差值越大,蒸餾效果也就越明顯。一般來講,teacher的模型引數在蒸餾訓練的過程中保持不變,達到訓練student模型的目的。蒸餾的損失函式distillation loss計算teacher和student之前輸出預測值的差別,和student的loss加在一起作為整個訓練loss,來進行梯度更新,最終得到一個更高效能和精度的student模型。

2、 半監督蒸餾

半監督方式的蒸餾利用了teacher模型的預測資訊作為標籤,來對student網路進行監督學習,如上圖(b)。那麼不同於傳統離線蒸餾的方式,在對student模型訓練之前,先輸入部分的未標記的資料,利用teacher網路輸出標籤作為監督資訊再輸入到student網路中,來完成蒸餾過程,這樣就可以使用更少標註量的資料集,達到提升模型精度的目的。

3、 自監督蒸餾

自監督蒸餾相比於傳統的離線蒸餾的方式是不需要提前訓練一個teacher網路模型,而是student網路本身的訓練完成一個蒸餾過程,如上圖(c)。具體實現方式 有多種,例如先開始訓練student模型,在整個訓練過程的最後幾個epoch的時候,利用前面訓練的student作為監督模型,在剩下的epoch中,對模型進行蒸餾。這樣做的好處是不需要提前訓練好teacher模型,就可以變訓練邊蒸餾,節省整個蒸餾過程的訓練時間。

1.3 知識蒸餾的功能

1、提升模型精度

使用者如果對目前的網路模型A的精度不是很滿意,那麼可以先訓練一個更高精度的teacher模型B(通常引數量更多,時延更大),然後用這個訓練好的teacher模型B對student模型A進行知識蒸餾,得到一個更高精度的模型。

2、降低模型時延,壓縮網路引數

使用者如果對目前的網路模型A的時延不滿意,可以先找到一個時延更低,引數量更小的模型B,通常來講,這種模型精度也會比較低,然後通過訓練一個更高精度的teacher模型C來對這個引數量小的模型B進行知識蒸餾,使得該模型B的精度接近最原始的模型A,從而達到降低時延的目的。

3、圖片標籤之間的域遷移

使用者使用狗和貓的資料集訓練了一個teacher模型A,使用香蕉和蘋果訓練了一個teacher模型B,那麼就可以用這兩個模型同時蒸餾出一個可以識別狗,貓,香蕉以及蘋果的模型,將兩個不同與的資料集進行整合和遷移。

圖2 影象域遷移訓練

4、降低標註量

該功能可以通過半監督的蒸餾方式來實現,使用者利用訓練好的teacher網路模型來對未標註的資料集進行蒸餾,達到降低標註量的目的。

1.4 知識蒸餾的原理

圖3 知識蒸餾原理介紹

一般使用蒸餾的時候,往往會找一個引數量更小的student網路,那麼相比於teacher來說,這個輕量級的網路不能很好的學習到資料集之前隱藏的潛在關係,如上圖所示,相比於one hot的輸出,teacher網路是將輸出的logits進行了softmax,更加平滑的處理了標籤,即將數字1輸出成了0.6(對1的預測)和0.4(對0的預測)然後輸入到student網路中,相比於1來說,這種softmax含有更多的資訊。好模型的目標不是擬合訓練資料,而是學習如何泛化到新的資料。所以蒸餾的目標是讓student學習到teacher的泛化能力,理論上得到的結果會比單純擬合訓練資料的student要好。另外,對於分類任務,如果soft targets的熵比hard targets高,那顯然student會學習到更多的資訊。最終student模型學習的是teacher模型的泛化能力,而不是“過擬合訓練資料”

二、動手實踐知識蒸餾

ModelArts模型市場中的efficientDet目標檢測演算法目前已經支援知識蒸餾,使用者可以通過下面的一個案例,來入門和熟悉知識蒸餾在檢測網路中的使用流程。

2.1 準備資料集

資料集使用kaggle公開的Images of Canine Coccidiosis Parasite的識別任務,下載地址:https://www.kaggle.com/kvinicki/canine-coccidiosis。使用者下載資料集之後,釋出到ModelArts的資料集管理中,同時進行資料集切分,預設按照8:2的比例切分成train和eval兩種。

2.2 訂閱市場演算法efficientDet

進到模型市場演算法介面,找到efficientDet演算法,點選“訂閱”按鈕

圖4 市場訂閱efficientDet演算法

然後到演算法管理介面,找到已經訂閱的efficientDet,點選同步,就可以進行演算法訓練

圖5 演算法管理同步訂閱演算法

2.3 訓練student網路模型

起一個efficientDet的訓練作業,model_name=efficientdet-d0,資料集選用2.1釋出的已經切分好的資料集,選擇好輸出路徑,點選建立,具體建立引數如下:

圖6 建立student網路的訓練作業

得到訓練的模型精度資訊在評估結果介面,如下:

圖7 student模型訓練結果

可以看到student的模型精度在0.8473。

2.4 訓練teacher網路模型

下一步就是訓練一個teacher模型,按照efficientDet文件的描述,這裡選擇efficientdet-d3,同時需要新增一個引數,表明該訓練作業生成的模型是用來作為知識蒸餾的teacher模型,新起一個訓練作業,具體引數如下:

圖8 teacher模型訓練作業引數

得到的模型精度在評估結果一欄,具體如下:

圖9 teacher模型訓練結果

可以看到teacher的模型精度在0.875。

2.5 使用知識蒸餾提升student模型精度

有了teacher網路,下一步就是進行知識蒸餾了,按照官方文件,需要填寫teacher model url,具體填寫的內容就是2.4訓練輸出路徑下面的model目錄,注意需要選到model目錄的那一層級,同時需要新增引數use_offline_kd=True,具體模型引數如下所示:

圖10 採用知識蒸餾的student模型訓練作業引數

得到模型精度在評估結果一欄,具體如下:

圖11 使用知識蒸餾之後的student模型訓練結果

可以看到經過知識蒸餾之後的student的模型精度提升到了0.863,精度相比於之前的student網路提升了1.6%百分點。

2.6 線上推理部署

訓練之後的模型就可以進行模型部署了,具體點選“建立模型”

 

圖12 建立模型

介面會自動讀取模型訓練的儲存路徑,點選建立:

圖13 匯入模型

模型部署成功之後,點選建立線上服務:

圖14 部署線上服務

部署成功就可以進行線上預測了:

圖15 模型推理結果展示

三、知識蒸餾目前的應用領域

目前知識蒸餾的演算法已經廣泛應用到影象語義識別,目標檢測等場景中,並且針對不同的研究場景,蒸餾方法都做了部分的定製化修改,同時,在行人檢測,人臉識別,姿態檢測,影象域遷移,視訊檢測等方面,知識蒸餾也是作為一種提升模型效能和精度的重要方法,隨著深度學習的發展,這種技術也會更加的成熟和穩定。

參考文獻:

[1]Data Distillation: Towards Omni-Supervised Learning

[2]On the Efficacy of Knowledge Distillation

[3]Knowledge Distillation and Student-Teacher Learning for Visual Intelligence: A Review and New Outlooks

[4]Towards Understanding Knowledge Distillation

[5]Model Compression via Distillation and Quantization

 

點選關注,第一時間瞭解華為雲新鮮技