1. 程式人生 > >Pytorch只加載自己需要的模型引數(修改模型後)

Pytorch只加載自己需要的模型引數(修改模型後)

給定一個預訓練模型,如果你對模型結構做了一定的修改,那麼可以只加載未改變的模型引數,從而加快模型的訓練。程式碼如下:
pretrained_dict = ‘…….pkl’#預訓練模型引數儲存地址
model_dict = model.state_dict() #自己的模型引數變數

pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#去除一些不需要的引數
model_dict.update(pretrained_dict)#引數更新
model.load_state_dict(model_dict)#載入