解決pytorch多GPU訓練儲存的模型,在單GPU環境下加載出錯問題
阿新 • • 發佈:2020-06-23
背景
在公司用多卡訓練模型,得到權值檔案後儲存,然後回到實驗室,沒有多卡的環境,用單卡訓練,載入模型時出錯,因為單卡機器上,沒有使用DataParallel來載入模型,所以會出現載入錯誤。
原因
DataParallel包裝的模型在儲存時,權值引數前面會帶有module字元,然而自己在單卡環境下,沒有用DataParallel包裝的模型權值引數不帶module。本質上儲存的權值檔案是一個有序字典。
解決方法
1.在單卡環境下,用DataParallel包裝模型。
2.自己重寫Load函式,靈活。
from collections import OrderedDict def myOwnLoad(model,check): modelState = model.state_dict() tempState = OrderedDict() for i in range(len(check.keys())-2): print modelState.keys()[i],check.keys()[i] tempState[modelState.keys()[i]] = check[check.keys()[i]] temp = [[0.02]*1024 for i in range(200)] # mean=0,std=0.02 tempState['myFc.weight'] = torch.normal(mean=0,std=torch.FloatTensor(temp)).cuda() tempState['myFc.bias'] = torch.normal(mean=0,std=torch.FloatTensor([0]*200)).cuda() model.load_state_dict(tempState) return model
補充知識:Pytorch:多GPU訓練網路與單GPU訓練網路儲存模型的區別
測試環境:Python3.6 + Pytorch0.4
在pytorch中,使用多GPU訓練網路需要用到 【nn.DataParallel】:
gpu_ids = [0,1,2,3] device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU執行 net = LeNet() if len(gpu_ids) > 1: net = nn.DataParallel(net,device_ids=gpu_ids) net = net.to(device)
而使用單GPU訓練網路:
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU執行
net = LeNet().to(device)
由於多GPU訓練使用了 nn.DataParallel(net,device_ids=gpu_ids) 對網路進行封裝,因此在原始網路結構中添加了一層module。網路結構如下:
DataParallel( (module): LeNet( (conv1): Conv2d(3,6,kernel_size=(5,5),stride=(1,1)) (conv2): Conv2d(6,16,1)) (fc1): Linear(in_features=400,out_features=120,bias=True) (fc2): Linear(in_features=120,out_features=84,bias=True) (fc3): Linear(in_features=84,out_features=10,bias=True) ) )
而不使用多GPU訓練的網路結構如下:
LeNet( (conv1): Conv2d(3,1)) (conv2): Conv2d(6,1)) (fc1): Linear(in_features=400,bias=True) (fc2): Linear(in_features=120,bias=True) (fc3): Linear(in_features=84,bias=True) )
由於在測試模型時不需要用到多GPU測試,因此在儲存模型時應該把module層去掉。如下:
if len(gpu_ids) > 1: t.save(net.module.state_dict(),"model.pth") else: t.save(net.state_dict(),"model.pth")
以上這篇解決pytorch多GPU訓練儲存的模型,在單GPU環境下加載出錯問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。