知識蒸餾--Distilling the Knowledge in a Neural Network
知識蒸餾--Distilling the Knowledge in a Neural Network
動機
- 在普遍的訓練當中,經過 softmax 後都是最大化正標籤的概率,最小化負標籤的概率。但是這樣訓練的效果導致了正標籤的概率輸出越來越接近 1, 負標籤的概率越來越接近 0, 使得原本的負標籤的概率有一些是比其他的大得多,但是這種相對關係在經過多次訓練後他們的概率都是漸漸趨近 0. 所以導致模型輸出概率原本還有大量的資訊丟失。
- 一般認為,用於訓練的目標函式應該是儘可能反應使用者的真實標籤,因為越接近真實標籤那麼精度或者準確率卻高,但是機器學習所需要的是學習它的泛化能力,並不是它的真值標籤。在遇到未曾講過的樣本要能夠進行正確的分類。然而由於條件所限,我們一般把提升模型的泛化能力這個目標簡化為訓練模型在訓練集上對真值標籤的預測能力,我們也認為,訓練得到的模型對真值標籤的預測能力越強,它的泛化能力也應該越強,這也是很合理的。
貢獻
- 提出了知識蒸餾,把大模型對樣本輸出的概率向量作為軟目標“soft targets”,去讓小模型的輸出儘量去和這個軟目標靠。
思想
預備知識
\(softmax\) 函式:
\[q_i = \frac{e^{z_i}}{\sum_j e^{z_j}} \]交叉熵:
\[CE = -\sum_{x \in X} p(x) \log q(x) = H(p) + D_{KL}(p||q) \]交叉熵損失函式:
假設樣本數量 \(n\) ,真實標籤為 T,\(T = \{t_1, t_2, ... , t_c\} \quad \quad t_i \in R^{c}\),預測值為 Y, \(Y = \{y_1,y_2, ..., y_n\}\quad\quad y_i \in R^p\)
對於分類問題,在神經網路最後一層的啟用函式一般是 \(softmax\) 歸一化並且輸出概率向量,並通過最小化交叉熵損失函式進行反向傳播更新引數,假設標籤為 \(one-hot\) 形式, 那麼上式交叉熵損失函式化簡為 \(l_{CE} = -\sum_{i = 1}^{n} \log y_{ik} \quad s.t\quad t_{ik} == 1\) ,最小化損失函式,就是要最大化概率值,既是使真標籤對應概率不斷趨近於 1, 負標籤的概率不斷趨近於 0,最後輸出的概率(目標)就是趨近 \(one-hot\)
與 hard target 對應的就是 soft target,soft target 中分佈的熵相對更高,其蘊含的知識就更加豐富。
那如何才能得到熵相對高的 soft target 呢,由於傳統的 \(softmax\) 如果直接作為 soft target 那麼導致負標籤的概率趨近 0,正標籤的概率趨近 1, 熵就會相對的低,這時候就要引入溫度
\[q_i = \frac{e^{z_i/T}}{\sum_j e^{z_j/T}} \]為什麼引入溫度就可以,可以通過看一下 指數函式 \(y=e^x\) 的曲線。
可以發現,\(Z\) 如果本身就是類似\([-10, 20, 40, -5, -8]\) 這樣的向量經過指數後原本存在的差距將會變得越來越大(正的越來越大,負的越來越趨近 0),導致 softmax 歸一化後的更偏向於 \(one -hot\) 形式,而這也是指數函式達到一個放縮的效果。 而通過引入溫度後,使得 \(Z\) 都趨於集中,讓 softmax 輸出更加平滑(也就是 \(Z\) 在指數函式上更趨於平滑 )。 讓它的分佈的熵更大,負標籤攜帶的資訊會被相對地放大,模型訓練將更加關注負標籤。
框架
teacher 模型就是一個大的複雜模型,效果好,student 模型是一個輕量型的模型,我們的目的是將student 模型經過訓練後達到 teacher 模型的效果,或者比teacher模型更好。對於訓練 student 模型中損失函式主要由兩部分組成,一部分使 teacher 模型經過知識蒸餾後得到 soft loss,在於自己模型普通訓練後的 hard loss。總的loss = soft loss + hard loss。
核心
就是在原來的損失函式進行加權 \(L_{soft}\):
\[L = \alpha L_{soft} + \beta L_{hard} \]其中 \(\alpha 、\beta\) 分別為超引數
\[L_{soft} = -\sum_j^{n}p_j^T \log(q_j^T)\\ p_i^T = \frac{e^{v_i/T}}{\sum_{k= 1}e^{v_k/T}} \\ q_i^T = \frac{e^{z_i/T}}{\sum_{k= 1}e^{z_k/T}} \\ \]\(v_i,z_i\) 分別表示在相同溫度下 teacher,student 網路下輸出的預測的每個值 (logits),還沒有經過 softmax
\[L_{head} = -\sum_{i = 1}^N t_i \log q^1_i \\ p_i^1 = \frac{e^v_i}{\sum_{k= 1}e^{v_k}} \]