PyTorch載入預訓練模型例項(pretrained)
阿新 • • 發佈:2020-01-17
使用預訓練模型的程式碼如下:
# 載入預訓練模型 resNet50 = models.resnet50(pretrained=True) ResNet50 = ResNet(Bottleneck,[3,4,6,3],num_classes=2) # 讀取引數 pretrained_dict = resNet50.state_dict() model_dict = ResNet50.state_dict() # 將pretained_dict裡不屬於model_dict的鍵剔除掉 pretrained_dict = {k: v for k,v in pretrained_dict.items() if k in model_dict} # 更新現有的model_dict model_dict.update(pretrained_dict) # 載入真正需要的state_dict ResNet50.load_state_dict(model_dict)
以上這篇PyTorch載入預訓練模型例項(pretrained)就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。