1. 程式人生 > 程式設計 >PyTorch載入預訓練模型例項(pretrained)

PyTorch載入預訓練模型例項(pretrained)

使用預訓練模型的程式碼如下:

# 載入預訓練模型
 resNet50 = models.resnet50(pretrained=True)
 ResNet50 = ResNet(Bottleneck,[3,4,6,3],num_classes=2)

 # 讀取引數
 pretrained_dict = resNet50.state_dict()
 model_dict = ResNet50.state_dict()

 # 將pretained_dict裡不屬於model_dict的鍵剔除掉
 pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in model_dict}

 # 更新現有的model_dict
 model_dict.update(pretrained_dict)

 # 載入真正需要的state_dict
 ResNet50.load_state_dict(model_dict)

以上這篇PyTorch載入預訓練模型例項(pretrained)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。