1. 程式人生 > 實用技巧 >tensor的拼接與拆分

tensor的拼接與拆分

tensor的拼接與拆分

目錄

cat函式

例子:成績單的合併

【班級1~4 學生 得分】

【班級5~9 學生 得分】

在0維進行合併,非cat的維度必須一致

a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]

stack函式

會新新增一個維度,要保證兩個stack的tensor的維度一摸一樣,在理解方面是添加了新的概念在裡面。

例子:

一班:【32個學生 每個學生8門課程】

二班:【32個學生 每個學生8門課程】

stack之後變為【兩個班級 每個班級32個學生 每個學生有8門課程】

a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]

split函式

split函式按照長度來拆分

例子1:

引數說明:【1,1】表示前面的長度為1,後面的長度也是1

a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]

例子2:

引數說明:【2,1】表示前面的長度為2,後面的長度為1

(不規則分割的引數含義)

a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]

chunk函式

根據數量來進行分割(儘量實現整除,後面除不盡的留給最後)

例子:

a = torch.rand(6,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])

例子2:

a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)

#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])

例子3:

a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)

#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])