Distilling the Knowledge in a Neural Network
阿新 • • 發佈:2020-10-26
目錄
很有可能和\(3\)長的比較像, 這是one-hot無法帶來的資訊.
概
\[q_1 = \frac{\exp(z_i/T)}{\sum_j \exp(z_j/T)}. \]主要內容
這篇文章或許重點是在遷移學習上, 一個重點就是其認為soft labels (即概率向量)比hard target (one-hot向量)含有更多的資訊. 比如, 數字模型判別數字\(2\)為\(3\)和\(7\)的概率分別是0.1, 0.01, 這說明這個數字\(2\)
於是乎, 現在的情況是:
-
以及有一個訓練好的且往往效果比較好但是計量大的模型\(t\);
-
我們打算用一個小的模型\(s\)去近似這個已有的模型;
-
策略是每個樣本\(x\), 先根據\(t(x)\)獲得soft logits \(z \in \mathbb{R}^K\), 其中\(K\)是類別數, 且\(z\)未經softmax.
-
最後我們希望根據下面的損失函式來訓練\(s\):
\[\mathcal{L(x, y)} = T^2 \cdot \mathcal{L}_{soft}(x, y) + \lambda \cdot\mathcal{L}_{hard}(x, y) \]
其中
\[\mathcal{L}_{soft}(x, y) = -\sum_{i=1}^K p_i(x) \log q_i (x) = -\sum_{i=1}^K \frac{\exp(v_i(x)/T)}{\sum_j \exp(v_j(x)/T)} \log \frac{\exp(z_i(x)/T)}{\sum_j \exp(z_j(x)/T)} \]\[\mathcal{L}_{hard}(x, y) = -\log \frac{\exp(z_y(x))}{\sum_j \exp(z_j(x))} \]至於\(T^2\)是怎麼來的, 這是為了配平梯度的magnitude.
\[\begin{array}{ll} \frac{\partial \mathcal{L}_{soft}}{\partial z_k} &= -\sum_{i=1}^K \frac{p_i}{q_i} \frac{\partial q_i}{\partial z_k} = -\frac{1}{T}p_k -\sum_{i=1}^K \frac{p_i}{q_i} \cdot (-\frac{1}{T}q_i q_k) \\ &= -\frac{1}{T} (p_k -\sum_{i=1}^K p_iq_k) = \frac{1}{T}(q_k-p_k) \\ &= \frac{1}{T} (\frac{e^{z_i/T}}{\sum_j e^{z_j/T}} - \frac{e^{v_i/T}}{\sum_j e^{v_j/T}}) . \end{array} \]當\(T\)足夠大的時候, 並假設\(\sum_j z_j=0 = \sum_j v_j =0\), 有
\[\frac{\partial \mathcal{L}_{soft}}{\partial z_k} \approx \frac{1}{KT^2} (z_k - v_k). \]故需要加個\(T^2\)取抵消這部分的影響.
程式碼
其實一直很好奇的一點是這部分程式碼在pytorch裡是怎麼實現的, 畢竟pytorch裡的交叉熵是
\[-\log p_y(x) \]另外很噁心的一點是, 我看大家都用的是 KLDivLOSS, 但是其實現居然是:
\[\mathcal{L}(x, y) = y \cdot \log y - y \cdot x, \]注: 這裡的\(\cdot\)是逐項的.
def kl_div(x, y):
return y * (torch.log(y) - x)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="none")
loss2 = kl_div(x, y)
這時, 出來的結果長這樣
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
tensor([[-1.5965, 2.2040, -0.8753],
[ 3.9795, 0.0910, 1.0761]])
又或者:
def kl_div(x, y):
return (y * (torch.log(y) - x)).sum(dim=1).mean()
torch.manual_seed(10086)
x = torch.randn(2, 3)
y = torch.randn(2, 3).abs() + 1
loss1 = F.kl_div(x, y, reduction="batchmean")
loss2 = kl_div(x, y)
print(loss1)
print(loss2)
tensor(2.4394)
tensor(2.4394)
所以如果真要弄, 應該要
def soft_loss(z, v, T=10.):
# z: logits
# v: targets
z = F.log_softmax(z / T, dim=1)
v = F.softmax(v / T, dim=1)
return F.kl_div(z, v, reduction="batchmean")