1. 程式人生 > 其它 >【Pytorch】unsqueeze()與squeeze()詳解

【Pytorch】unsqueeze()與squeeze()詳解

技術標籤:Pytorchpythonpytorch

squeeze()主要用於對資料的維度進行壓縮或者解壓的

官方文件

torch.squeeze()

對資料的維度進行壓縮

使用方式:torch.squeeze(input, dim=None, out=None)

將輸入張量形狀中的1 去除並返回。 如果輸入是形如(A×1×B×1×C×1×D),那麼輸出形狀就為: (A×B×C×D)

當給定dim時,那麼擠壓操作只在給定維度上。例如,輸入形狀為: (A×1×B), squeeze(input, 0)

將會保持張量不變,只有用 squeeze(input, 1),形狀會變成 (A×B)。

注意:

輸入的張量和返回的張量共用一段記憶體空間,改變了其中一個,其餘的都會被改變

引數:

  • input (Tensor) – 輸入張量
  • dim (int, optional) – 如果給定,則input只會在給定維度擠壓
  • out (Tensor, optional) – 輸出張量

多維張量本質上就是一個變換,如果維度是 1 ,那麼,1 僅僅起到擴充維度的作用,而沒有其他用途,因而,在進行降維操作時,為了加快計算,是可以去掉這些 1 的維度。

import torch


m = torch.zeros(
2,1,2,1,2) print(m.size()) # print torch.Size([2, 1, 2, 1, 2]) # 預設刪除維度為1的維數 n = torch.squeeze(m) print(n.size()) # print torch.Size([2, 2, 2]) # 當給定dim引數值的時候,那麼擠壓操作只會給定在指定的維度上 m = torch.zeros(2,1,2,1,2) n = torch.squeeze(m,0) print(n.size()) # print torch.Size([2, 1, 2, 1, 2]) n = torch.squeeze(m,
1) print(n.size()) #print torch.Size([2, 2, 1, 2]) n = torch.squeeze(m, 2) print(n.size()) #print torch.Size([2, 1, 2, 1, 2]) n = torch.squeeze(m, 3) print(n.size()) #print torch.Size([2, 1, 2, 2]) print('-' * 100) p = torch.zeros(2, 1, 1) print(p) # tensor([[[0.]], # [[0.]]]) print(p.numpy()) # [[[0.]] # [[0.]]] print(p.size()) # torch.Size([2, 1, 1]) q = torch.squeeze(p) print(q) # tensor([0., 0.]) print(q.numpy()) # [0. 0.] print(q.size()) # torch.Size([2])

==總結:==這個函式主要對資料的維度進行壓縮,去掉維數為1的的維度,比如是一行或者一列這種,一個一行三列(1,3)的數去掉第一個維數為一的維度之後就變成(3)行。squeeze(a)就是將a中所有為1的維度刪掉。不為1的維度沒有影響。a.squeeze(N) 就是去掉a中指定的維數為一的維度。還有一種形式就是b=torch.squeeze(a,N) a中去掉指定的定的維數為一的維度。

torch.unsqueeze()

torch.unsqueeze(input, dim, out=None)

  • 作用:擴充套件維度

    ​ 返回一個新的張量,對輸入的既定位置插入維度 1

  • 注意: 返回張量與輸入張量共享記憶體,所以改變其中一個的內容會改變另一個。

如果dim為負,則將會被轉化dim+input.dim()+1

  • 引數:
  • tensor (Tensor) – 輸入張量
  • dim (int) – 插入維度的索引
  • out (Tensor, optional) – 結果張量
import torch

x = torch.Tensor([1, 2, 3, 4])  # torch.Tensor是預設的tensor型別(torch.FlaotTensor)的簡稱。

print('-' * 50)
print(x)  # tensor([1., 2., 3., 4.])
print(x.size())  # torch.Size([4])
print(x.dim())  # 1
print(x.numpy())  # [1. 2. 3. 4.]

print('-' * 50)
print(torch.unsqueeze(x, 0))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim())  # 2
print(torch.unsqueeze(x, 0).numpy())  # [[1. 2. 3. 4.]]

print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, 1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim())  # 2

print('-' * 50)
print(torch.unsqueeze(x, -1))
# tensor([[1.],
#         [2.],
#         [3.],
#         [4.]])
print(torch.unsqueeze(x, -1).size())  # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim())  # 2

print('-' * 50)
print(torch.unsqueeze(x, -2))  # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, -2).size())  # torch.Size([1, 4])
print(torch.unsqueeze(x, -2).dim())  # 2

==總結:==給指定位置加上維數為一的維度,比如原本有個三行的資料(3),在0的位置加了一維就變成一行三列(1,3)。a.unsqueeze(N) 就是在a中指定位置N加上一個維數為1的維度。

參考:https://zhuanlan.zhihu.com/p/86763381