1. 程式人生 > >pytorch任意批量將資料傳入cuda

pytorch任意批量將資料傳入cuda

import torch

def cuda(*pargs):
    if torch.cuda.is_available():
        return (data.cuda() for data in pargs)

if __name__ == '__main__':

    x = torch.randn(1,3,256,256)
    y = torch.randn(1,3,512,512)

    a,b = cuda(x,y)
    print(a.cpu().data.numpy().shape)
    print(b.cpu().data.numpy().shape)