1. 程式人生 > 其它 >部分網路載入預訓練模型程式碼實現

部分網路載入預訓練模型程式碼實現

引自:https://www.cxyzjd.com/article/Gavinmiaoc/80514528

方式一: 自己網路和預訓練網路結構一致的層,使用預訓練網路對應層的引數批量初始化

檢視預訓練模型引數:

path = 'I:/迅雷下載/alexnet-owt-4df8aa71.pth'
pretrained_dict = torch.load(path)
for k, v in pretrained_dict.items():  # k 引數名 v 對應引數值
        print(k)
model_dict = model.state_dict()                                    # 取出自己網路的引數字典
pretrained_dict = torch.load("I:/迅雷下載/alexnet-owt-4df8aa71.pth")# 載入預訓練網路的引數字典
# 取出預訓練網路的引數字典
keys = []
for k, v in pretrained_dict.items():
       keys.append(k)
i = 0

# 自己網路和預訓練網路結構一致的層,使用預訓練網路對應層的引數初始化
for k, v in model_dict.items():
    if v.size() == pretrained_dict[keys[i]].size():
         model_dict[k] = pretrained_dict[keys[i]]
         #print(model_dict[k])
         i = i + 1
model.load_state_dict(model_dict)

方式二:自己網路和預訓練網路結構一致的層,按層初始化

# 加粗自己定義一個網路叫CNN
model = CNN()
model_dict = model.state_dict()                                    # 取出自己網路的引數

for k, v in model_dict.items():                                    # 檢視自己網路引數各層叫什麼名稱
       print(k)

pretrained_dict = torch.load("I:/迅雷下載/alexnet-owt-4df8aa71.pth")# 載入預訓練網路的引數
for k, v in pretrained_dict.items():                                    # 檢視預訓練網路引數各層叫什麼名稱
       print(k)


# 對應層賦值初始化
model_dict['conv1.0.weight'] = pretrained_dict['features.0.weight'] # 將自己網路的conv1.0層的權重初始化為預訓練網路features.0層的權重
model_dict['conv1.0.bias'] = pretrained_dict['features.0.bias']    # 將自己網路的conv1.0層的偏置項初始化為預訓練網路features.0層的偏置項

model_dict['conv2.1.weight'] = pretrained_dict['features.3.weight']
model_dict['conv1.1.bias'] = pretrained_dict['features.3.bias']

model_dict['conv2.1.weight'] = pretrained_dict['features.6.weight']
model_dict['conv2.1.bias'] = pretrained_dict['features.6.bias']

... ...