Concatenating and Splitting PyTorch Tensors
阿新 • Published: 2020-08-28
torch.cat
```python
import torch

a = torch.randn(3, 4)  # random tensor of shape (3, 4)
b = torch.randn(2, 4)  # random tensor of shape (2, 4)
print("a:")
print(a)
print("b:")
print(b)
print("concatenation result:")
# print(torch.cat([a, b], dim=1))  # error: a has 3 rows and b has 2 (3 != 2),
#                                  # so they cannot be concatenated along dim=1
print(torch.cat([a, b], dim=0))  # dim 0 grows: a and b are joined into a tensor
                                 # of shape (5, 4), i.e. stacked row-wise
                                 # (vertically, along the direction rows grow)
```

```
a:
tensor([[ 2.1754,  1.4698,  1.4103, -0.2498],
        [ 0.3248, -1.9372, -0.9310, -0.3833],
        [-0.3603, -0.0271, -0.1942, -0.0345]])
b:
tensor([[-0.3917,  0.1332, -1.0066,  0.3633],
        [-0.2378,  0.5224,  1.1371,  0.6401]])
concatenation result:
tensor([[ 2.1754,  1.4698,  1.4103, -0.2498],
        [ 0.3248, -1.9372, -0.9310, -0.3833],
        [-0.3603, -0.0271, -0.1942, -0.0345],
        [-0.3917,  0.1332, -1.0066,  0.3633],
        [-0.2378,  0.5224,  1.1371,  0.6401]])
```
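For completeness, here is a small sketch of my own (not from the original post) showing that cat along dim=1 does work when the tensors agree on every dimension except the one being concatenated:

```python
import torch

# a and b have the same number of rows (3), so column-wise cat is allowed
a = torch.randn(3, 4)
b = torch.randn(3, 2)
c = torch.cat([a, b], dim=1)  # dim 1 grows: 4 + 2 = 6 columns
print(c.shape)  # torch.Size([3, 6])
```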
torch.stack
Requirement: the two tensors must have exactly the same shape before stacking.
```python
import torch

a = torch.randn(3, 4)
b = torch.randn(3, 4)
print("a: ")
print(a)
print("b: ")
print(b)

c = torch.stack([a, b], dim=0)  # returns a tensor of shape (2, 3, 4); the new
                                # dimension of size 2 indexes a and b respectively
# print("c: ")
# print(c)
# a:
# tensor([[ 0.9722,  0.7518,  0.8787,  1.1068],
#         [-0.3760, -0.3623, -0.9563,  0.3909],
#         [ 0.7292, -0.0121, -0.4910,  2.1195]])
# b:
# tensor([[-0.4713, -2.5941,  1.8245,  0.2314],
#         [ 1.3405,  0.3472,  1.1083,  0.7682],
#         [-1.1995,  0.6853, -0.7180,  0.7114]])
# c:
# tensor([[[ 0.9722,  0.7518,  0.8787,  1.1068],
#          [-0.3760, -0.3623, -0.9563,  0.3909],
#          [ 0.7292, -0.0121, -0.4910,  2.1195]],
#
#         [[-0.4713, -2.5941,  1.8245,  0.2314],
#          [ 1.3405,  0.3472,  1.1083,  0.7682],
#          [-1.1995,  0.6853, -0.7180,  0.7114]]])

d = torch.stack([a, b], dim=1)  # returns a tensor of shape (3, 2, 4); for each i,
                                # the new dimension of size 2 pairs row i of a
                                # with row i of b
print("d: ")
print(d)
# a:
# tensor([[ 0.9923,  0.2121, -0.8024,  0.4230],
#         [-0.6697, -2.7528, -1.2073,  0.9505],
#         [ 0.5162, -0.9078, -0.6087, -0.4061]])
# b:
# tensor([[ 1.7505,  1.0785,  0.8404,  0.2812],
#         [ 0.9416, -0.7041, -1.6120,  0.3687],
#         [ 0.4658,  0.1827,  0.2341, -0.1813]])
# d:
# tensor([[[ 0.9923,  0.2121, -0.8024,  0.4230],
#          [ 1.7505,  1.0785,  0.8404,  0.2812]],
#
#         [[-0.6697, -2.7528, -1.2073,  0.9505],
#          [ 0.9416, -0.7041, -1.6120,  0.3687]],
#
#         [[ 0.5162, -0.9078, -0.6087, -0.4061],
#          [ 0.4658,  0.1827,  0.2341, -0.1813]]])
```
Summary:
The keyword argument dim is interpreted somewhat differently here than in cat.
In cat, dim refers to an existing dimension of the input tensors: dim=0 concatenates along the original axis 0, and dim=1 concatenates along the original axis 1.
In stack, dim refers to the position of the newly inserted dimension: dim=0 inserts the new dimension at position 0 of the resulting tensor.
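One way to make this distinction concrete (an illustration of my own, not from the post): stack(dim=d) behaves like unsqueezing each input at position d and then concatenating along d.

```python
import torch

a = torch.randn(3, 4)
b = torch.randn(3, 4)

# stack inserts a new dimension at position 1 -> shape (3, 2, 4)
s = torch.stack([a, b], dim=1)

# equivalent: give each tensor a new size-1 dim at position 1, then cat there
c = torch.cat([a.unsqueeze(1), b.unsqueeze(1)], dim=1)

print(s.shape)           # torch.Size([3, 2, 4])
print(torch.equal(s, c)) # True
```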
Next, splitting a tensor with split.
split
split divides a tensor into pieces of the specified lengths.
```python
import torch

a = torch.randn(3, 4)
print('a :')
print(a)

print("split along dim 0: ")
print(a.split([1, 2], dim=0))  # split dim 0 into lengths [1, 2], producing two
                               # tensors of shapes (1, 4) and (2, 4)
# a :
# tensor([[ 0.2626,  0.9178, -1.3622, -0.9441],
#         [-0.1259, -0.3336,  0.2441, -0.2219],
#         [ 1.5535,  0.7683, -1.7978, -1.1680]])
# split along dim 0:
# (tensor([[ 0.2626,  0.9178, -1.3622, -0.9441]]),
#  tensor([[-0.1259, -0.3336,  0.2441, -0.2219],
#          [ 1.5535,  0.7683, -1.7978, -1.1680]]))

print("split along dim 1: ")
print(a.split([2, 2], dim=1))  # split dim 1 into lengths [2, 2], producing two
                               # tensors of shape (3, 2) each
# a :
# tensor([[-0.5111,  1.3557, -0.1616, -0.2014],
#         [-1.1011,  1.0982, -0.1794, -0.3510],
#         [-1.4451, -2.1550, -1.9542, -1.6998]])
# split along dim 1:
# (tensor([[-0.5111,  1.3557],
#          [-1.1011,  1.0982],
#          [-1.4451, -2.1550]]),
#  tensor([[-0.1616, -0.2014],
#          [-0.1794, -0.3510],
#          [-1.9542, -1.6998]]))
```
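split also accepts a single integer instead of a list of lengths, in which case every piece has that length and only the last piece may be shorter. A quick sketch of my own to show this:

```python
import torch

a = torch.randn(5, 4)
parts = a.split(2, dim=0)  # pieces of length 2 along dim 0; last one is shorter
print([p.shape for p in parts])
# [torch.Size([2, 4]), torch.Size([2, 4]), torch.Size([1, 4])]
```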
chunk
chunk can be understood as an "evenly dividing" split. When the dimension length is not divisible by the requested number of chunks, no error is raised, but the result may differ from what you expect: each chunk gets ceil(length / chunks) elements, so the last chunk is smaller and you may even get fewer chunks than requested. It is safest to use chunk only when the length divides evenly.
```python
import torch

a = torch.randn(4, 6)
print("a :")
print(a)

# print(a.chunk(2, dim=0))  # returns two tensors of shape (2, 6)
# a :
# tensor([[ 0.8008,  0.6612, -0.0450,  0.3255, -0.4714, -0.2343],
#         [ 1.3068, -0.2587,  1.4938,  0.1859,  0.4674,  0.0086],
#         [ 0.4522,  0.0220, -0.2653, -0.0588,  1.1987,  0.7340],
#         [ 0.1547, -0.2052, -0.8919, -0.8763, -0.6897,  0.2474]])
# (tensor([[ 0.8008,  0.6612, -0.0450,  0.3255, -0.4714, -0.2343],
#          [ 1.3068, -0.2587,  1.4938,  0.1859,  0.4674,  0.0086]]),
#  tensor([[ 0.4522,  0.0220, -0.2653, -0.0588,  1.1987,  0.7340],
#          [ 0.1547, -0.2052, -0.8919, -0.8763, -0.6897,  0.2474]]))

print(a.chunk(2, dim=1))  # returns two tensors of shape (4, 3)
# a :
# tensor([[-0.4875,  1.4914,  0.2244, -0.5883, -0.5951, -0.4857],
#         [-0.1344, -0.6973, -0.2042,  2.5817, -0.7972, -0.6522],
#         [ 1.4379, -0.1185,  0.4457, -1.1168,  1.0184, -0.5088],
#         [-0.7692,  1.4040, -0.2799,  1.1515,  0.2329,  0.4926]])
# (tensor([[-0.4875,  1.4914,  0.2244],
#          [-0.1344, -0.6973, -0.2042],
#          [ 1.4379, -0.1185,  0.4457],
#          [-0.7692,  1.4040, -0.2799]]),
#  tensor([[-0.5883, -0.5951, -0.4857],
#          [ 2.5817, -0.7972, -0.6522],
#          [-1.1168,  1.0184, -0.5088],
#          [ 1.1515,  0.2329,  0.4926]]))
```
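To illustrate the caveat about non-divisible lengths (my own sketch, not from the post): chunking a length-5 dimension into 3 parts gives chunks of size ceil(5/3) = 2 until the elements run out, so the sizes are 2, 2, 1 rather than anything "even".

```python
import torch

a = torch.randn(5, 4)
parts = a.chunk(3, dim=0)  # 5 rows into 3 chunks: sizes 2, 2, 1
print([p.shape[0] for p in parts])  # [2, 2, 1]
```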