tensor的拼接與拆分
阿新 • • 發佈:2020-09-01
tensor的拼接與拆分
目錄
cat函式
例子:成績單的合併
【班級1~4 學生 得分】
【班級5~9 學生 得分】
在0維進行合併,非cat的維度必須一致
a = torch.rand(4,32,8)
b = torch.rand(5,32,8)
c = torch.cat([a,b],dim=0)
c.shape()
#[9,32,8]
stack函式
會新新增一個維度,要保證兩個stack的tensor的維度一摸一樣
,在理解方面是添加了新的概念在裡面。
例子:
一班:【32個學生 每個學生8門課程】
二班:【32個學生 每個學生8門課程】
stack之後變為【兩個班級 每個班級32個學生 每個學生有8門課程】
a = torch.rand(32,8)
b = torch.rand(32,8)
torch.stack([a,b],dim=0).shape
#[2 32 8]
split函式
split函式按照長度來拆分
例子1:
引數說明:【1,1】表示前面的長度為1,後面的長度也是1
a = torch.rand(2,32,8)
b,c = torch.split([1,1],dim=0)
b.shape
#[1,32,8]
c.shape()
#[1,32,8]
例子2:
引數說明:【2,1】表示前面的長度為2,後面的長度為1
a = torch.rand(3,32,8)
b,c = torch.split([2,1],dim=0)
b.shape
#[2,32,8]
c.shape()
#[1,32,8]
chunk函式
根據數量來進行分割(儘量實現整除,後面除不盡的留給最後)
例子:
a = torch.rand(6,32,8) b,c,d= torch.chunk(a,3,dim=0) print(b.shape) print(c.shape) print(d.shape) #torch.Size([2, 32, 8]) #torch.Size([2, 32, 8]) #torch.Size([2, 32, 8])
例子2:
a = torch.rand(5,32,8)
b,c,d= torch.chunk(a,3,dim=0)
print(b.shape)
print(c.shape)
print(d.shape)
#torch.Size([2, 32, 8])
#torch.Size([2, 32, 8])
#torch.Size([1, 32, 8])
例子3:
a = torch.rand(5,32,8)
b,c= torch.chunk(a,2,dim=0)
print(b.shape)
print(c.shape)
#torch.Size([3, 32, 8])
#torch.Size([2, 32, 8])