1. 程式人生 > 實用技巧 >torch 中的torch.squeeze()和torch.unsqueeze()

torch 中的torch.squeeze()和torch.unsqueeze()

1. torch.squeeze(input, dim=None, out=None)

input是輸入的引數,dim是指定要合併維度為1的所在維度

當dim=0時原樣輸出,當dim=1時合併維度為1的行,dim=2 合併維度為1的列,當所在的行和列的維度不為1時原樣輸出,

例如:

import torch as t

a=t.araneg(8).view(4,1,2)#生成四個一行兩列的tensor

t.squeeze(a,dim=0)#原樣輸出tensor

結果為:

tensor([[[0, 1]],

        [[2, 3]],

        [[4, 5]],

        [[6, 7]]])
當dim=1時,由於行所在的維度為1,因此合併行,生成4行兩列的tensor
t.squeeze(a,dim=1)
結果為:
tensor([[0, 1],
        [2, 3],
        [4, 5],
        [6, 7]])
當dim=2時,由於列的維度為2,所以原樣輸出
t.squeeze(a,dim=2)
結果為:
tensor([[[0, 1]],

        [[2, 3]],

        [[4, 5]],

        [[6, 7]]])
但是我們將原來的tensor換成2個4行1列的tensor,當dim=2時,將會生成2行4列的tensor
import torch as t

a=t.arange(8).view(2,4,1)
t.squeeze(a,dim=2)
輸出的結果為:
tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

2. torch.unsqueeze(input, dim, out=None)

插入一個維度唯一的維度

dim=0原樣輸出,dim=1在山上插入維度為1 的維度,dim=2在列上插入維度為1 的維度

比如某一tensor為(2,4)

當dim=1時變成(2,1,4)兩個1行4列的tensor

當dim=2時變成(2,4,1)變成兩個4行1列的tensor

import torch as t
a=t.arange(8).view(2,4)

a

結果為:

tensor([[0, 1, 2, 3],
        [4, 5, 6, 7]])

1)t.unsqueeze(a,dim=0)#原樣輸出

結果為:

tensor([[[0, 1, 2, 3],
         [4, 5, 6, 7]]])

2)t.unsqueeze(a,dim=1)#在行的維度加1,變成2個1行4列,即(2,1,4)
tensor([[[0, 1, 2, 3]],

        [[4, 5, 6, 7]]])
3)t.unqueeze(a,dim=2)#在列的維度加1 變成2個4行1列,即(2,4,1)
tensor([[[0],
         [1],
         [2],
         [3]],

        [[4],
         [5],
         [6],
         [7]]])