1. 程式人生 > 程式設計 >PyTorch的torch.cat用法

PyTorch的torch.cat用法

1. 字面理解:

torch.cat是將兩個張量(tensor)拼接在一起,cat是concatnate的意思,即拼接,聯絡在一起。

2. 例子理解

>>> import torch
>>> A=torch.ones(2,3) #2x3的張量(矩陣)                   
>>> A
tensor([[ 1.,1.,1.],[ 1.,1.]])
>>> B=2*torch.ones(4,3)#4x3的張量(矩陣)                  
>>> B
tensor([[ 2.,2.,2.],[ 2.,2.]])
>>> C=torch.cat((A,B),0)#按維數0(行)拼接
>>> C
tensor([[ 1.,2.]])
>>> C.size()
torch.Size([6,3])
>>> D=2*torch.ones(2,4) #2x4的張量(矩陣)
>>> C=torch.cat((A,D),1)#按維數1(列)拼接
>>> C
tensor([[ 1.,2.]])
>>> C.size()
torch.Size([2,7])

上面給出了兩個張量A和B,分別是2行3列,4行3列。即他們都是2維張量。因為只有兩維,這樣在用torch.cat拼接的時候就有兩種拼接方式:按行拼接和按列拼接。即所謂的維數0和維數1.

C=torch.cat((A,0)就表示按維數0(行)拼接A和B,也就是豎著拼接,A上B下。此時需要注意:列數必須一致,即維數1數值要相同,這裡都是3列,方能列對齊。拼接後的C的第0維是兩個維數0數值和,即2+4=6.

C=torch.cat((A,1)就表示按維數1(列)拼接A和B,也就是橫著拼接,A左B右。此時需要注意:行數必須一致,即維數0數值要相同,這裡都是2行,方能行對齊。拼接後的C的第1維是兩個維數1數值和,即3+4=7.

從2維例子可以看出,使用torch.cat((A,dim)時,除拼接維數dim數值可不同外其餘維數數值需相同,方能對齊。

3.例項

在深度學習處理影象時,常用的有3通道的RGB彩色影象及單通道的灰度圖。張量size為cxhxw,即通道數x影象高度x影象寬度。在用torch.cat拼接兩張影象時一般要求影象大小一致而通道數可不一致,即h和w同,c可不同。當然實際有3種拼接方式,另兩種好像不常見。比如經典網路結構:U-Net

PyTorch的torch.cat用法

裡面用到4次torch.cat,其中copy and crop操作就是通過torch.cat來實現的。可以看到通過上取樣(up-conv 2x2)將原始影象h和w變為原來2倍,再和左邊直接copy過來的同樣h,w的影象拼接。這樣做,可以有效利用原始結構資訊。

4.總結

使用torch.cat((A,dim)時,除拼接維數dim數值可不同外其餘維數數值需相同,方能對齊。

補充知識:PyTorch的concat也就是torch.cat例項

我就廢話不多說了,大家還是直接看程式碼吧~

import torch
a = torch.ones([1,2])
b = torch.ones([1,2])
torch.cat([a,b],1)
 1 1 1 1
[torch.FloatTensor of size 1x4]

以上這篇PyTorch的torch.cat用法就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。