1. 程式人生 > 其它 >模型蒸餾工作 & logit

模型蒸餾工作 & logit

接上一篇文章:

https://www.cnblogs.com/charlesblc/p/15965479.html

今天重點看這篇文章:

https://mp.weixin.qq.com/s/tKfHq49heakvjM0EVQPgHw

Distilled BiLSTM/BERT-PKD/DistillBERT/TinyBERT/MobileBERT/MiniLM六大經典模型

在蒸餾的過程中,我們將原始大模型稱為教師模型(teacher),新的小模型稱為學生模型(student),訓練集中的標籤稱為hard label,教師模型預測的概率輸出為soft label,temperature(T)是用來調整soft label的超引數。

蒸餾這個概念之所以work,核心思想是因為好模型的目標不是擬合訓練資料,而是學習如何泛化到新的資料。所以蒸餾的目標是讓學生模型學習到教師模型的泛化能力,理論上得到的結果會比單純擬合訓練資料的學生模型要好。

那對於簡單的二分類任務來說,直接拿教師預測的0/1結果會與訓練集差不多,沒什麼意義,那拿概率值是不是好一些?於是Hinton採用了教師模型的輸出概率q,同時為了更好地控制輸出概率的平滑程度,給教師模型的softmax中加了一個引數T。

有了教師模型的輸出後,學生模型的目標就是儘可能擬合教師模型的輸出,新loss就變成了:

因為在求梯度時新的目標函式會導致梯度是以前的,所以要再乘上,不然T變了的話hard label不減小(T=1),但soft label會變。

有同學可能會疑惑:如果可以擬合prob,那直接擬合logits可以嗎?

當然可以,Hinton在論文中進行了證明,如果T很大,且logits分佈的均值為0時,優化概率交叉熵和logits的平方差是等價的。

BERT蒸餾

在BERT提出後,如何瘦身就成了一個重要分支。主流的方法主要有剪枝、蒸餾和量化。量化的提升有限,因此免不了採用剪枝+蒸餾的融合方法來獲取更好的效果。

從各個研究看來,蒸餾的提升一方面來源於從精調階段蒸餾->預訓練階段蒸餾,另一方面則來源於蒸餾最後一層知識->蒸餾隱層知識->蒸餾注意力矩陣

DistillBERT (NIPS2019)

最終損失函式由MLM loss、教師-學生最後一層的交叉熵、隱層之間的cosine loss組成

。從消融實驗可以看出,MLM loss對於學生模型的表現影響較小,同時初始化也是影響效果的重要因素:

TinyBERT(EMNLP2019)

TinyBERT[5]就提出了two-stage learning框架,分別在預訓練和精調階段蒸餾教師模型,得到了引數量減少7.5倍,速度提升9.4倍的4層BERT,效果可以達到教師模型的96.8%,同時這種方法訓出的6層模型甚至接近BERT-base,超過了BERT-PKD和DistillBERT。

TinyBERT的教師模型採用BERT-base。作者參考其他研究的結論,即注意力矩陣可以捕獲到豐富的知識,提出了注意力矩陣的蒸餾,採用教師-學生注意力矩陣logits的MSE作為損失函式(這裡不取attention prob是實驗表明前者收斂更快)。另外,作者還對embedding進行了蒸餾,同樣是採用MSE作為損失。

BERT蒸餾技巧

介紹了BERT蒸餾的幾個經典模型之後,真正要上手前還是要把幾個問題都考慮清楚,下面就來討論一些蒸餾中的變數。

剪層還是減維度?

這個選擇取決於是預訓練蒸餾還是精調蒸餾。預訓練蒸餾的資料比較充分,可以參考MiniLM、MobileBERT或者TinyBERT那樣進行剪層+維度縮減。

對於針對某項任務、只想蒸餾精調後BERT的情況,則推薦進行剪層,同時利用教師模型的層對學生模型進行初始化。從BERT-PKD以及DistillBERT的結論來看,採用skip(每隔n層選一層)的初始化策略會優於只選前k層或後k層。

T和如何設定?

超引數主要控制soft label和hard label的loss比例,Distilled BiLSTM在實驗中發現只使用soft label會得到最好的效果。個人建議讓soft label佔比更多一些,一方面是強迫學生更多的教師知識,另一方面實驗證實soft target可以起到正則化的作用,讓學生模型更穩定地收斂。

超引數T主要控制預測分佈的平滑程度,TinyBERT實驗發現T=1更好,BERT-PKD的搜尋空間則是{5, 10, 20}。因此建議在1~20之間多嘗試幾次,T越大越能學到teacher模型的泛化資訊。比如MNIST在對2的手寫圖片分類時,可能給2分配0.9的置信度,3是1e-6,7是1e-9,從這個分佈可以看出2和3有一定的相似度,這種時候可以調大T,讓概率分佈更平滑,展示teacher更多的泛化能力。