1. 程式人生 > 實用技巧 >利用resnet預訓練權重,出現“bn1.num_batches_tracked”或者“layer.0.bn1.num_batches_tracked" 的解決辦法

利用resnet預訓練權重,出現“bn1.num_batches_tracked”或者“layer.0.bn1.num_batches_tracked" 的解決辦法

報錯的原因在於Pytorch0.4之後,在BN層後新增加了track_running_stats這個引數。

在呼叫預訓練引數模型是,官方給定的預訓練模型是在pytorch0.4之前,因此,呼叫預訓練引數時,需要過濾掉“num_batches_tracked”。

以resnet50為例:

為了載入不同層的權重,採用兩個函式,如下:load_partial_param用於載入layer1, layer2, layer3, layer4的權重權重,load_specific_param用於載入第一層的權重引數。

為了避免“num_batches_tracked”報錯,採用下面的程式碼即可,更改部分為紅色字型(方法簡單,但可以滿足要求)。

 def load_partial_param(self, state_dict, model_index, model_path):
     param_dict = torch.load(model_path)
     for i in state_dict:
         key = 'layer{}.'.format(model_index)+i
         if 'tracked' in key[-7:]:
             continue
         state_dict[i].copy_(param_dict[key])
     del param_dict

def load_specific_param(self, state_dict, param_name, model_path):
    param_dict = torch.load(model_path)
    for i in state_dict:
        key = param_name + '.' + i
        if 'num_batches_tracked' in key:
            continue
        state_dict[i].copy_(param_dict[key])
    del param_dict