1. 程式人生 > >PyTorch的torch.cat

PyTorch的torch.cat

import torch

a = torch.ones([1,2])

b = torch.ones([1,2])

torch.cat([a,b],1)

1111

[torch.FloatTensor of size 1x4]

如果第二個引數是1,torch.cat就是將a,b 按列放在一起,大小為torch.Size([1,4])。如果第二個引數是0,則按行

行放在一起,大小為 torch.Size([2, 2]) 。

自己的一點體會:

在深度學習處理影象的時候,經常要考慮將多張不同圖片輸入到網路,這時需要用torch.cat([image1,image2],1), 第二個引數是1;

如果將多個顏色通道拼成一張圖片,比如將L, A, B三通道或R, G, B三通道組合成一張彩色圖片,此時應用torch.cat([L,A,B,].0),第二個引數為0.