1. 程式人生 > >Pytorch 多GPU執行

Pytorch 多GPU執行

self.net = netword()
n_gpu = 1
if n_gpu==1:
    self.net = torch.nn.DataParallel(self.net).cuda(device=0)
else:
    gpus = []
    for i in range(n_gpu):
    	gpus.append(i)
    	self.net = torch.nn.DataParallel(self.net, device_ids=gpus).cuda()