1. 程式人生 > 程式設計 >解決pytorch多GPU訓練儲存的模型,在單GPU環境下加載出錯問題

解決pytorch多GPU訓練儲存的模型,在單GPU環境下加載出錯問題

背景

在公司用多卡訓練模型,得到權值檔案後儲存,然後回到實驗室,沒有多卡的環境,用單卡訓練,載入模型時出錯,因為單卡機器上,沒有使用DataParallel來載入模型,所以會出現載入錯誤。

原因

DataParallel包裝的模型在儲存時,權值引數前面會帶有module字元,然而自己在單卡環境下,沒有用DataParallel包裝的模型權值引數不帶module。本質上儲存的權值檔案是一個有序字典。

解決方法

1.在單卡環境下,用DataParallel包裝模型。

2.自己重寫Load函式,靈活。

from collections import OrderedDict
def myOwnLoad(model,check):
  modelState = model.state_dict()
  tempState = OrderedDict()
  for i in range(len(check.keys())-2):
    print modelState.keys()[i],check.keys()[i]
    tempState[modelState.keys()[i]] = check[check.keys()[i]]
  temp = [[0.02]*1024 for i in range(200)] # mean=0,std=0.02
  tempState['myFc.weight'] = torch.normal(mean=0,std=torch.FloatTensor(temp)).cuda()
  tempState['myFc.bias']  = torch.normal(mean=0,std=torch.FloatTensor([0]*200)).cuda()

  model.load_state_dict(tempState)
  return model

補充知識:Pytorch:多GPU訓練網路與單GPU訓練網路儲存模型的區別

測試環境:Python3.6 + Pytorch0.4

在pytorch中,使用多GPU訓練網路需要用到 【nn.DataParallel】:

gpu_ids = [0,1,2,3]
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU執行
net = LeNet()
if len(gpu_ids) > 1:
  net = nn.DataParallel(net,device_ids=gpu_ids)
net = net.to(device)

而使用單GPU訓練網路:

device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU執行
net = LeNet().to(device)

由於多GPU訓練使用了 nn.DataParallel(net,device_ids=gpu_ids) 對網路進行封裝,因此在原始網路結構中添加了一層module。網路結構如下:

DataParallel(
 (module): LeNet(
  (conv1): Conv2d(3,6,kernel_size=(5,5),stride=(1,1))
  (conv2): Conv2d(6,16,1))
  (fc1): Linear(in_features=400,out_features=120,bias=True)
  (fc2): Linear(in_features=120,out_features=84,bias=True)
  (fc3): Linear(in_features=84,out_features=10,bias=True)
 )
)

而不使用多GPU訓練的網路結構如下:

LeNet(
 (conv1): Conv2d(3,1))
 (conv2): Conv2d(6,1))
 (fc1): Linear(in_features=400,bias=True)
 (fc2): Linear(in_features=120,bias=True)
 (fc3): Linear(in_features=84,bias=True)
)

由於在測試模型時不需要用到多GPU測試,因此在儲存模型時應該把module層去掉。如下:

if len(gpu_ids) > 1:
  t.save(net.module.state_dict(),"model.pth")
else:
  t.save(net.state_dict(),"model.pth")

以上這篇解決pytorch多GPU訓練儲存的模型,在單GPU環境下加載出錯問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。