Knowledge Distillation

1. Distilling the Knowledge in a Neural Network

The seminal work on knowledge distillation. The distillation process in brief:

  • First train a large network, e.g. ResNet50, on the classification task
  • Build a smaller network to be trained, e.g. MobileNetV2 (a setup sketch follows this list)
  • While training the small network, also run inference with the large network; the large network's outputs guide the small network (the KD loss measures how similar the two output distributions are)
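
A minimal setup sketch for the first two steps, assuming a 10-class classification task; the torchvision model constructors are real, while the class count and the way the teacher is trained/loaded are purely illustrative:

import torchvision

# Step 1: the teacher, e.g. ResNet50. In practice it is trained on the task first
# (or loaded from a checkpoint); here it is only frozen for distillation.
teacher = torchvision.models.resnet50(num_classes=10)
teacher.eval()
for p in teacher.parameters():
    p.requires_grad_(False)

# Step 2: the student, e.g. MobileNetV2, which will be trained under the teacher's guidance.
student = torchvision.models.mobilenet_v2(num_classes=10)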

Similar code: link

Similar article: link

The process is fairly simple:

import torch.nn as nn
import torch.nn.functional as F

# The teacher output and the student output give loss1; the student output and the
# ground-truth labels give loss2; the two are combined with a weighting factor and
# backpropagated together.
def loss_fn_kd(outputs, labels, teacher_outputs, params):
    """
    Compute the knowledge-distillation (KD) loss given outputs and labels.
    Hyperparameters: temperature and alpha.
    NOTE: the KL divergence in PyTorch, comparing the softened outputs of the
    teacher and the student, expects the input tensor to be log probabilities!
    See Issue #2.
    """
    alpha = params.alpha
    T = params.temperature
    KD_loss = nn.KLDivLoss()(F.log_softmax(outputs / T, dim=1),
                             F.softmax(teacher_outputs / T, dim=1)) * (alpha * T * T) + \
              F.cross_entropy(outputs, labels) * (1. - alpha)

    return KD_loss
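
A hedged usage sketch of loss_fn_kd in a single training step, assuming the teacher/student pair from the setup sketch above; train_loader, the optimizer settings, and the alpha/temperature values are all placeholders:

import torch
from types import SimpleNamespace

params = SimpleNamespace(alpha=0.9, temperature=4.0)   # illustrative hyperparameters
optimizer = torch.optim.SGD(student.parameters(), lr=0.01)

for images, labels in train_loader:                    # train_loader is assumed to exist
    with torch.no_grad():                              # the teacher is only run for inference
        teacher_outputs = teacher(images)
    outputs = student(images)                          # student forward pass
    loss = loss_fn_kd(outputs, labels, teacher_outputs, params)
    optimizer.zero_grad()
    loss.backward()                                    # gradients flow only through the student
    optimizer.step()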

2. Fast Human Pose Estimation Pytorch

Paper: link

Code: link

The paper has no substantial innovation: the KD loss simply uses MSE to measure the similarity between the student's and teacher's heatmap distributions, the normal supervision loss is also MSE, and the two are combined with a fixed weighting.
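
A minimal sketch of that combination (not the repository's exact code), assuming hypothetical heatmap tensors student_heatmaps, teacher_heatmaps, gt_heatmaps of shape (N, K, H, W) and a weighting factor alpha:

import torch.nn.functional as F

# All tensor names and alpha are placeholders for illustration.
kd_loss = F.mse_loss(student_heatmaps, teacher_heatmaps)   # distillation term (teacher heatmaps)
gt_loss = F.mse_loss(student_heatmaps, gt_heatmaps)        # normal supervision term (ground truth)
pose_loss = alpha * kd_loss + (1 - alpha) * gt_loss        # weighted combination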

Note: some people claim that distillation requires the teacher and student to have similar network architectures, otherwise performance may actually drop (to be verified).

# When a keypoint is not visible (unlabeled), only the KD loss is applied; when it is
# visible, both the KD loss and the normal training loss are applied.
for j in range(0, len(output)):
    _output = output[j]
    for i in range(gtmask.shape[0]):
        if gtmask[i] < 0.1:
            # unlabeled data, gtmask = 0.0, KD loss only
            # divide by train_batch to keep the scales comparable
            kdloss_unlabeled += criterion(_output[i, :, :, :], toutput[i, :, :, :]) / train_batch
        else:
            # labeled data: KD loss + ground-truth loss
            gtloss += criterion(_output[i, :, :, :], target_var[i, :, :, :]) / train_batch
            kdloss += criterion(_output[i, :, :, :], toutput[i, :, :, :]) / train_batch

loss_labeled = kdloss_alpha * kdloss + (1 - kdloss_alpha) * gtloss
total_loss   = loss_labeled + unkdloss_alpha * kdloss_unlabeled