部分網路載入預訓練模型程式碼實現
阿新 • • 發佈:2021-09-03
引自:https://www.cxyzjd.com/article/Gavinmiaoc/80514528
方式一: 自己網路和預訓練網路結構一致的層,使用預訓練網路對應層的引數批量初始化
檢視預訓練模型引數:
path = 'I:/迅雷下載/alexnet-owt-4df8aa71.pth'
pretrained_dict = torch.load(path)
for k, v in pretrained_dict.items(): # k 引數名 v 對應引數值
print(k)
model_dict = model.state_dict() # 取出自己網路的引數字典 pretrained_dict = torch.load("I:/迅雷下載/alexnet-owt-4df8aa71.pth")# 載入預訓練網路的引數字典 # 取出預訓練網路的引數字典 keys = [] for k, v in pretrained_dict.items(): keys.append(k) i = 0 # 自己網路和預訓練網路結構一致的層,使用預訓練網路對應層的引數初始化 for k, v in model_dict.items(): if v.size() == pretrained_dict[keys[i]].size(): model_dict[k] = pretrained_dict[keys[i]] #print(model_dict[k]) i = i + 1 model.load_state_dict(model_dict)
方式二:自己網路和預訓練網路結構一致的層,按層初始化
# 加粗自己定義一個網路叫CNN model = CNN() model_dict = model.state_dict() # 取出自己網路的引數 for k, v in model_dict.items(): # 檢視自己網路引數各層叫什麼名稱 print(k) pretrained_dict = torch.load("I:/迅雷下載/alexnet-owt-4df8aa71.pth")# 載入預訓練網路的引數 for k, v in pretrained_dict.items(): # 檢視預訓練網路引數各層叫什麼名稱 print(k) # 對應層賦值初始化 model_dict['conv1.0.weight'] = pretrained_dict['features.0.weight'] # 將自己網路的conv1.0層的權重初始化為預訓練網路features.0層的權重 model_dict['conv1.0.bias'] = pretrained_dict['features.0.bias'] # 將自己網路的conv1.0層的偏置項初始化為預訓練網路features.0層的偏置項 model_dict['conv2.1.weight'] = pretrained_dict['features.3.weight'] model_dict['conv1.1.bias'] = pretrained_dict['features.3.bias'] model_dict['conv2.1.weight'] = pretrained_dict['features.6.weight'] model_dict['conv2.1.bias'] = pretrained_dict['features.6.bias'] ... ...