torch 中的torch.squeeze()和torch.unsqueeze()
阿新 • • 發佈:2020-09-10
1. torch.squeeze(input, dim=None, out=None)
input是輸入的引數,dim是指定要合併維度為1的所在維度
當dim=0時原樣輸出,當dim=1時合併維度為1的行,dim=2 合併維度為1的列,當所在的行和列的維度不為1時原樣輸出,
例如:
import torch as t
a=t.araneg(8).view(4,1,2)#生成四個一行兩列的tensor
t.squeeze(a,dim=0)#原樣輸出tensor
結果為:
tensor([[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]])當dim=1時,由於行所在的維度為1,因此合併行,生成4行兩列的tensor
t.squeeze(a,dim=1)
結果為:
tensor([[0, 1], [2, 3], [4, 5], [6, 7]])
當dim=2時,由於列的維度為2,所以原樣輸出
t.squeeze(a,dim=2)
結果為:
tensor([[[0, 1]], [[2, 3]], [[4, 5]], [[6, 7]]])
但是我們將原來的tensor換成2個4行1列的tensor,當dim=2時,將會生成2行4列的tensor
import torch as t
a=t.arange(8).view(2,4,1)
t.squeeze(a,dim=2)
輸出的結果為:
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
2. torch.unsqueeze(input, dim, out=None)
插入一個維度唯一的維度
dim=0原樣輸出,dim=1在山上插入維度為1 的維度,dim=2在列上插入維度為1 的維度
比如某一tensor為(2,4)
當dim=1時變成(2,1,4)兩個1行4列的tensor
當dim=2時變成(2,4,1)變成兩個4行1列的tensor
import torch as t
a=t.arange(8).view(2,4)
a
結果為:
tensor([[0, 1, 2, 3],
[4, 5, 6, 7]])
1)t.unsqueeze(a,dim=0)#原樣輸出
結果為:
tensor([[[0, 1, 2, 3],
[4, 5, 6, 7]]])
2)t.unsqueeze(a,dim=1)#在行的維度加1,變成2個1行4列,即(2,1,4)
tensor([[[0, 1, 2, 3]], [[4, 5, 6, 7]]])
3)t.unqueeze(a,dim=2)#在列的維度加1 變成2個4行1列,即(2,4,1)
tensor([[[0], [1], [2], [3]], [[4], [5], [6], [7]]])