Knowledge Distillation
1. Distilling the Knowledge in a Neural Network
The seminal work on knowledge distillation. A brief outline of the distillation process:
- First train a large network, e.g. ResNet-50, for the classification task
- Build a small network to be trained, e.g. MobileNetV2
- While training the small network, run inference with the large network; the large network's outputs guide the small one (the KD loss measures how similar the two output distributions are); a setup sketch follows this list
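A minimal setup sketch for the three steps above. The specific torchvision models and the optimizer settings are illustrative assumptions, not taken from the original paper:

```python
import torch
from torchvision import models

# Step 1: a large, already-trained classification network acts as the teacher.
teacher = models.resnet50(pretrained=True)
teacher.eval()                              # the teacher is only used for inference
for p in teacher.parameters():
    p.requires_grad = False                 # no gradients flow into the teacher

# Step 2: a small student network that will be trained from scratch.
student = models.mobilenet_v2(num_classes=1000)

# Step 3 happens in the training loop; the optimizer only updates the student.
optimizer = torch.optim.SGD(student.parameters(), lr=0.01, momentum=0.9)
```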
Similar code: link
Similar article: link
The process is fairly simple:
```python
# The teacher output and the student output give loss1, the student output and the label
# give loss2; the two are combined with a weighting factor and backpropagated.
import torch.nn as nn
import torch.nn.functional as F

def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs, labels.
    "Hyperparameters": temperature and alpha
    NOTE: the KL Divergence for PyTorch comparing the softmaxs of teacher and student
    expects the input tensor to be log probabilities! See Issue #2
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1),
                             F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)
    return KD_loss
```
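A rough sketch of how `loss_fn_kd` could be wired into one training step, reusing the `teacher`, `student` and `optimizer` from the setup sketch above; `train_loader` and the alpha/temperature values are assumptions:

```python
from types import SimpleNamespace

import torch

params = SimpleNamespace(alpha=0.9, temperature=4.0)   # assumed hyperparameters

for images, labels in train_loader:
    with torch.no_grad():                   # the teacher only runs inference
        teacher_outputs = teacher(images)
    outputs = student(images)               # student forward pass
    loss = loss_fn_kd(outputs, labels, teacher_outputs, params)
    optimizer.zero_grad()
    loss.backward()                         # gradients reach the student only
    optimizer.step()
```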
2. Fast Human Pose Estimation (PyTorch)
Paper: link
Code: link
The paper offers no real methodological novelty: the KD loss is simply an MSE between the student's and the teacher's heatmaps, the normal supervised loss is also MSE, and the two are combined with a fixed weighting.
Note: some people say the teacher and student network structures must be similar for distillation to help, otherwise performance actually drops (to be verified).
```python
# If the keypoints are not visible, apply only the KD loss; if they are visible,
# apply both the KD loss and the normal training loss.
for j in range(0, len(output)):
    _output = output[j]
    for i in range(gtmask.shape[0]):
        if gtmask[i] < 0.1:
            # unlabeled data, gtmask=0.0, kdloss only
            # need to divide by train_batch to keep the numbers comparable
            kdloss_unlabeled += criterion(_output[i, :, :, :], toutput[i, :, :, :]) / train_batch
        else:
            # labeled data: kdloss + gtloss
            gtloss += criterion(_output[i, :, :, :], target_var[i, :, :, :]) / train_batch
            kdloss += criterion(_output[i, :, :, :], toutput[i, :, :, :]) / train_batch

loss_labeled = kdloss_alpha * kdloss + (1 - kdloss_alpha) * gtloss
total_loss = loss_labeled + unkdloss_alpha * kdloss_unlabeled
```
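The same weighting can also be written without the Python loops. Below is a hedged, single-stage sketch that assumes the criterion is MSE over heatmaps of shape (N, K, H, W), that `gtmask` is a float tensor of shape (N,), and that the teacher heatmaps `toutput` are already detached; the default alpha values are placeholders, not values from the paper:

```python
import torch
import torch.nn.functional as F

def pose_kd_loss(output, toutput, target, gtmask,
                 kdloss_alpha=0.5, unkdloss_alpha=0.5):
    """Vectorised sketch of the per-sample weighting in the loop above,
    for a single output stage (the alpha defaults are assumptions)."""
    n = output.size(0)
    # Per-sample MSE, averaged over keypoints and pixels, matching nn.MSELoss().
    kd = F.mse_loss(output, toutput, reduction='none').mean(dim=(1, 2, 3))
    gt = F.mse_loss(output, target, reduction='none').mean(dim=(1, 2, 3))

    labeled = gtmask >= 0.1                  # labeled samples get both losses
    loss_labeled = (kdloss_alpha * kd[labeled]
                    + (1 - kdloss_alpha) * gt[labeled]).sum() / n
    loss_unlabeled = kd[~labeled].sum() / n  # unlabeled samples: KD loss only
    return loss_labeled + unkdloss_alpha * loss_unlabeled
```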