利用resnet預訓練權重,出現“bn1.num_batches_tracked”或者“layer.0.bn1.num_batches_tracked" 的解決辦法
阿新 • • 發佈:2020-08-20
報錯的原因在於Pytorch0.4之後,在BN層後新增加了track_running_stats這個引數。
在呼叫預訓練引數模型是,官方給定的預訓練模型是在pytorch0.4之前,因此,呼叫預訓練引數時,需要過濾掉“num_batches_tracked”。
以resnet50為例:
為了載入不同層的權重,採用兩個函式,如下:load_partial_param用於載入layer1, layer2, layer3, layer4的權重權重,load_specific_param用於載入第一層的權重引數。
為了避免“num_batches_tracked”報錯,採用下面的程式碼即可,更改部分為紅色字型(方法簡單,但可以滿足要求)。
def load_partial_param(self, state_dict, model_index, model_path): param_dict = torch.load(model_path) for i in state_dict: key = 'layer{}.'.format(model_index)+i if 'tracked' in key[-7:]: continue state_dict[i].copy_(param_dict[key]) del param_dict
def load_specific_param(self, state_dict, param_name, model_path): param_dict = torch.load(model_path) for i in state_dict: key = param_name + '.' + i if 'num_batches_tracked' in key: continue state_dict[i].copy_(param_dict[key]) del param_dict