pytorch學習-Tensor的組合與分塊 02
阿新 • • 發佈:2020-09-01
Tensor的組合與分塊
組合操作是指將不同的Tensor疊加起來, 主要有torch.cat()和torch.stack()兩個函式。 cat即concatenate的意思, 是指沿著已有的資料的某一維度進行拼接, 操作後資料的總維數不變, 在進行拼接時, 除了拼接的維度之外, 其他維度必須相同。 而torch.stack()函式指新增維度, 並按照指定的維度進行疊加,
1 import torch 2 3 # 建立兩個2×2的Tensor 4 a = torch.Tensor([[1,2],[3,4]]) 5 print(a,a.shape) 6 7View Codeb = torch.Tensor([[5,6],[7,8]]) 8 print(b,b.shape) 9 10 # 以第一維進行拼接, 則變成4×2的矩陣 11 c = torch.cat([a,b],0) 12 print(c,c.shape) 13 14 # 以第二維進行拼接, 則變成2*4的矩陣 15 d = torch.cat([a,b],1) 16 print(d,d.size())
結果輸出:
1 tensor([[1., 2.], 2 [3., 4.]]) torch.Size([2, 2]) 3 tensor([[5., 6.],View Code4 [7., 8.]]) torch.Size([2, 2]) 5 tensor([[1., 2.], 6 [3., 4.], 7 [5., 6.], 8 [7., 8.]]) torch.Size([4, 2]) 9 tensor([[1., 2., 5., 6.], 10 [3., 4., 7., 8.]]) torch.Size([2, 4])
1 import torch 2 3 # 建立兩個2×2的Tensor 4 a = torch.Tensor([[1,2],[3,4]])View Code5 print(a,a.shape) 6 7 >> tensor([[1., 2.], 8 [3., 4.]]) torch.Size([2, 2]) 9 10 b = torch.Tensor([[5,6],[7,8]]) 11 print(b,b.shape) 12 13 >> tensor([[5., 6.], 14 [7., 8.]]) torch.Size([2, 2]) 15 16 # 以第0維進行stack, 疊加的基本單位為序列本身, 即a與b, 因此輸出[a, b], 輸出維度為2×2×2 17 d=torch.stack([a,b],0) 18 print(d, d.size()) 19 >> tensor([[[1., 2.], 20 [3., 4.]], 21 22 [[5., 6.], 23 [7., 8.]]]) torch.Size([2, 2, 2]) 24 25 # 以第1維進行stack, 疊加的基本單位為每一行, 輸出維度為2×2×2 26 e=torch.stack([a,b],1) 27 print(e, e.shape) 28 29 >> tensor([[[1., 2.], 30 [5., 6.]], 31 32 [[3., 4.], 33 [7., 8.]]]) torch.Size([2, 2, 2]) 34 35 # 以第2維進行stack, 疊加的基本單位為每一行的每一個元素, 輸出維度為2×2×2 36 f=torch.stack([a,b],2) 37 print(f, f.shape) 38 39 >> tensor([[[1., 5.], 40 [2., 6.]], 41 42 [[3., 7.], 43 [4., 8.]]]) torch.Size([2, 2, 2])
分塊則是與組合相反的操作, 指將Tensor分割成不同的子Tensor,主要有torch.chunk()與torch.split()兩個函式, 前者需要指定分塊的數量,而後者則需要指定每一塊的大小, 以整型或者list來表示。 具體示例如下 :