Load balancing across multiple GPUs in PyTorch training
This post addresses a common problem when training models with PyTorch: GPU 0 ends up using noticeably more memory than the other cards.
As shown in the figure below: the machine has TITAN RTX cards with 24220 MiB of memory each, batch_size = 9, and three cards in use. GPU 0 already occupies 24207 MiB right after the run starts, when only a small amount of data has been moved onto the GPUs; once more data arrives, GPU 0 is bound to run out of memory. The reason GPU 0 uses more memory is that, during the backward pass, the loss and its gradient are computed on GPU 0 by default, so that card needs more memory than the others; exactly how much more depends on the network architecture.
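For reference, here is a minimal sketch (not from the original post; the nn.Linear model and random tensors are placeholders) of the default nn.DataParallel setup that produces this imbalance:

import torch
import torch.nn as nn

# nn.Linear stands in for the real network; three visible GPUs are assumed.
model = nn.DataParallel(nn.Linear(512, 10), device_ids=[0, 1, 2]).to('cuda:0')
criterion = nn.CrossEntropyLoss()

x = torch.randn(9, 512, device='cuda:0')         # batch_size = 9 -> 3 samples per GPU
y = torch.randint(0, 10, (9,), device='cuda:0')

out = model(x)            # per-GPU outputs are gathered back onto GPU 0
loss = criterion(out, y)  # so the loss and its gradient live on GPU 0,
loss.backward()           # which is where the extra memory goes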
So, to keep training from crashing with out of memory, the crude fix is to drop batch_size to 6, i.e. 2 samples per card.
See the problem? Cards 1 and 2 then use less than 16 GB of memory each. Just because GPU 0 might overflow by a small margin, we sacrifice batch_size.
Is there a more elegant way? Yes: borrow the BalancedDataParallel class used in transformer-xl. The code is as follows (code source):
import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import Scatter


def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """
    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None
def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs
class BalancedDataParallel(DataParallel):
    def __init__(self, gpu0_bsz, *args, **kwargs):
        # gpu0_bsz: how many samples of each batch go to GPU 0
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        # With gpu0_bsz == 0, GPU 0 only gathers outputs and receives no data chunk.
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        # Split the samples that do not go to GPU 0 evenly across the other GPUs.
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            # Hand out any leftover samples to the non-zero GPUs, one each.
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            # gpu0_bsz is no smaller than an even share: fall back to the default even split.
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)
As you can see from the code, BalancedDataParallel inherits from torch.nn.DataParallel and adds one extra argument, gpu0_bsz, the batch size assigned to GPU 0. Giving GPU 0 fewer samples evens out the memory usage between GPU 0 and the other cards. It is called like this:
from balanced_data_parallel import BalancedDataParallel  # assuming the class above is saved as balanced_data_parallel.py

if n_gpu > 1:
    model = BalancedDataParallel(2, model, dim=0).to(device)
    # model = torch.nn.DataParallel(model)
gpu0_bsz: the batch size assigned to GPU 0;
model: the model;
dim: the dimension along which the batch is split
So let's try setting batch_size to 8 with gpu0_bsz = 2; the result is shown below:
We successfully raised batch_size from 6 to 8. Because GPU 0 holds one sample fewer per batch, it uses less memory than the other cards. Trading a little headroom on one card for fuller use of the other cards' memory, and a larger batch_size overall, is well worth it, and the more cards you have, the more pronounced the advantage of this approach becomes.
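To make the split concrete, here is the arithmetic that the scatter method above performs for this configuration (a small illustration, not part of the original code):

# batch_size = 8, three GPUs, gpu0_bsz = 2
bsz, num_dev, gpu0_bsz = 8, 3, 2
bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)           # (8 - 2) // 2 = 3
chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)  # [2, 3, 3]
delta = bsz - sum(chunk_sizes)                         # 8 - 8 = 0, no leftovers to hand out
print(chunk_sizes)                                     # GPU 0 gets 2 samples, GPUs 1 and 2 get 3 each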