
Load Balancing for Multi-GPU Training in PyTorch

Tags: gpu, artificial intelligence, deep learning, pytorch

This post addresses a common problem when training PyTorch models on multiple GPUs: GPU 0 ends up using noticeably more memory than the other cards.

As shown in the figure below, the machine has TITAN RTX cards with 24220 MiB of memory each, batch_size = 9, and three GPUs in use. GPU 0 is already at 24207 MiB right after training starts, when only a small amount of data has been moved to the GPUs; with any more data, GPU 0 would certainly run out of memory. The reason GPU 0 uses more is that, during the backward pass, the loss and its gradients are by default computed on GPU 0, so it needs somewhat more memory than the other cards. Exactly how much more depends on the network architecture.
[Figure: GPU memory usage with batch_size = 9; GPU 0 at 24207 MiB]
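For reference, here is a minimal sketch of the default nn.DataParallel setup that produces this imbalance. The linear layer is only a placeholder for the real network; the point is the even 3-per-card split of a batch of 9 and the concentration of gathered outputs and gradients on GPU 0.

import torch
import torch.nn as nn

device = torch.device("cuda:0")
net = nn.Linear(1024, 1024)                      # placeholder for the real model
net = nn.DataParallel(net, device_ids=[0, 1, 2]).to(device)

x = torch.randn(9, 1024, device=device)          # batch_size = 9 -> 3 samples per GPU
loss = net(x).sum()                              # outputs are gathered onto GPU 0
loss.backward()                                  # parameter gradients accumulate on GPU 0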
So, to keep training from being interrupted by an out-of-memory error, the crude fix is to set batch_size to 6, i.e. put 2 samples on each card.

With batch_size = 6 and everything else unchanged, the result looks like this:
[Figure: GPU memory usage with batch_size = 6]
Notice the problem? GPUs 1 and 2 now use less than 16 GB each. We sacrificed batch_size just because GPU 0 might have exceeded its memory by a small margin.
Is there a more elegant solution? There is: borrow the BalancedDataParallel class used in transformer-xl. The code is as follows (code source):

import torch
from torch.nn.parallel.data_parallel import DataParallel
from torch.nn.parallel.parallel_apply import parallel_apply
from torch.nn.parallel._functions import Scatter

def scatter(inputs, target_gpus, chunk_sizes, dim=0):
    r"""
    Slices tensors into approximately equal chunks and
    distributes them across given GPUs. Duplicates
    references to objects that are not tensors.
    """

    def scatter_map(obj):
        if isinstance(obj, torch.Tensor):
            try:
                return Scatter.apply(target_gpus, chunk_sizes, dim, obj)
            except Exception:
                print('obj', obj.size())
                print('dim', dim)
                print('chunk_sizes', chunk_sizes)
                quit()
        if isinstance(obj, tuple) and len(obj) > 0:
            return list(zip(*map(scatter_map, obj)))
        if isinstance(obj, list) and len(obj) > 0:
            return list(map(list, zip(*map(scatter_map, obj))))
        if isinstance(obj, dict) and len(obj) > 0:
            return list(map(type(obj), zip(*map(scatter_map, obj.items()))))
        return [obj for targets in target_gpus]

    # After scatter_map is called, a scatter_map cell will exist. This cell
    # has a reference to the actual function scatter_map, which has references
    # to a closure that has a reference to the scatter_map cell (because the
    # fn is recursive). To avoid this reference cycle, we set the function to
    # None, clearing the cell
    try:
        return scatter_map(inputs)
    finally:
        scatter_map = None

def scatter_kwargs(inputs, kwargs, target_gpus, chunk_sizes, dim=0):
    r"""Scatter with support for kwargs dictionary"""
    inputs = scatter(inputs, target_gpus, chunk_sizes, dim) if inputs else []
    kwargs = scatter(kwargs, target_gpus, chunk_sizes, dim) if kwargs else []
    if len(inputs) < len(kwargs):
        inputs.extend([() for _ in range(len(kwargs) - len(inputs))])
    elif len(kwargs) < len(inputs):
        kwargs.extend([{} for _ in range(len(inputs) - len(kwargs))])
    inputs = tuple(inputs)
    kwargs = tuple(kwargs)
    return inputs, kwargs

class BalancedDataParallel(DataParallel):

    def __init__(self, gpu0_bsz, *args, **kwargs):
        self.gpu0_bsz = gpu0_bsz
        super().__init__(*args, **kwargs)

    def forward(self, *inputs, **kwargs):
        if not self.device_ids:
            return self.module(*inputs, **kwargs)
        if self.gpu0_bsz == 0:
            device_ids = self.device_ids[1:]
        else:
            device_ids = self.device_ids
        inputs, kwargs = self.scatter(inputs, kwargs, device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        replicas = self.replicate(self.module, self.device_ids)
        if self.gpu0_bsz == 0:
            replicas = replicas[1:]
        outputs = self.parallel_apply(replicas, device_ids, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def parallel_apply(self, replicas, device_ids, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        bsz = inputs[0].size(self.dim)
        num_dev = len(self.device_ids)
        gpu0_bsz = self.gpu0_bsz
        bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
        if gpu0_bsz < bsz_unit:
            # GPU 0 gets gpu0_bsz samples; the rest are split evenly across
            # the other GPUs, with any remainder handed out one sample at a
            # time starting from GPU 1.
            chunk_sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
            delta = bsz - sum(chunk_sizes)
            for i in range(delta):
                chunk_sizes[i + 1] += 1
            if gpu0_bsz == 0:
                chunk_sizes = chunk_sizes[1:]
        else:
            # If GPU 0's share is no smaller than the even split, fall back
            # to the default DataParallel scatter.
            return super().scatter(inputs, kwargs, device_ids)
        return scatter_kwargs(inputs, kwargs, device_ids, chunk_sizes, dim=self.dim)

As the code shows, BalancedDataParallel inherits from torch.nn.DataParallel and adds a single argument, gpu0_bsz, the batch size assigned to GPU 0. By giving GPU 0 fewer samples, it balances the memory usage of GPU 0 against the other cards. It is used like this:

# Assuming the class above is saved in a file named balanced_data_parallel.py;
# adjust the import to wherever you put it.
from balanced_data_parallel import BalancedDataParallel

if n_gpu > 1:
    model = BalancedDataParallel(2, model, dim=0).to(device)
    # model = torch.nn.DataParallel(model)


gpu0_bsz: the batch size placed on GPU 0;
model: the model;
dim: the dimension along which the batch is split.

So let's try raising batch_size back up to 8 with gpu0_bsz = 2. The result:
[Figure: GPU memory usage with batch_size = 8, gpu0_bsz = 2]
We successfully raised batch_size from 6 to 8. Because GPU 0 holds one fewer sample, it uses less memory than the other cards. Trading a bit of one card's headroom for fuller use of the others, and a larger batch_size overall, is well worth it, and the advantage becomes even more pronounced as the number of GPUs grows.
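For a concrete sense of how the split works, here is a small sketch that reproduces the chunk-size arithmetic from BalancedDataParallel.scatter for the setup above (3 GPUs, batch_size = 8, gpu0_bsz = 2). The helper function chunk_sizes is purely illustrative and not part of the class; no GPU is needed to run it.

def chunk_sizes(bsz, gpu0_bsz, num_dev):
    # Same arithmetic as BalancedDataParallel.scatter when gpu0_bsz < bsz_unit.
    bsz_unit = (bsz - gpu0_bsz) // (num_dev - 1)
    sizes = [gpu0_bsz] + [bsz_unit] * (num_dev - 1)
    delta = bsz - sum(sizes)
    for i in range(delta):          # hand out any remainder to GPUs 1..N-1
        sizes[i + 1] += 1
    return sizes

print(chunk_sizes(bsz=8, gpu0_bsz=2, num_dev=3))  # -> [2, 3, 3]
print(chunk_sizes(bsz=9, gpu0_bsz=1, num_dev=3))  # -> [1, 4, 4]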