【Pytorch】unsqueeze()與squeeze()詳解
squeeze()主要用於對資料的維度進行壓縮或者解壓的
torch.squeeze()
對資料的維度進行壓縮
使用方式:torch.squeeze(input, dim=None, out=None)
將輸入張量形狀中的1 去除並返回。 如果輸入是形如(A×1×B×1×C×1×D),那麼輸出形狀就為: (A×B×C×D)
當給定dim時,那麼擠壓操作只在給定維度上。例如,輸入形狀為: (A×1×B),
squeeze(input, 0)
將會保持張量不變,只有用squeeze(input, 1)
,形狀會變成 (A×B)。
注意:
輸入的張量和返回的張量共用一段記憶體空間,改變了其中一個,其餘的都會被改變
引數:
input (Tensor)
– 輸入張量dim (int, optional)
– 如果給定,則input只會在給定維度擠壓out (Tensor, optional)
– 輸出張量
多維張量本質上就是一個變換,如果維度是 1 ,那麼,1 僅僅起到擴充維度的作用,而沒有其他用途,因而,在進行降維操作時,為了加快計算,是可以去掉這些 1 的維度。
import torch
m = torch.zeros( 2,1,2,1,2)
print(m.size())
# print torch.Size([2, 1, 2, 1, 2])
# 預設刪除維度為1的維數
n = torch.squeeze(m)
print(n.size())
# print torch.Size([2, 2, 2])
# 當給定dim引數值的時候,那麼擠壓操作只會給定在指定的維度上
m = torch.zeros(2,1,2,1,2)
n = torch.squeeze(m,0)
print(n.size())
# print torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m, 1)
print(n.size())
#print torch.Size([2, 2, 1, 2])
n = torch.squeeze(m, 2)
print(n.size())
#print torch.Size([2, 1, 2, 1, 2])
n = torch.squeeze(m, 3)
print(n.size())
#print torch.Size([2, 1, 2, 2])
print('-' * 100)
p = torch.zeros(2, 1, 1)
print(p)
# tensor([[[0.]],
# [[0.]]])
print(p.numpy())
# [[[0.]]
# [[0.]]]
print(p.size())
# torch.Size([2, 1, 1])
q = torch.squeeze(p)
print(q)
# tensor([0., 0.])
print(q.numpy())
# [0. 0.]
print(q.size())
# torch.Size([2])
==總結:==這個函式主要對資料的維度進行壓縮,去掉維數為1的的維度,比如是一行或者一列這種,一個一行三列(1,3)的數去掉第一個維數為一的維度之後就變成(3)行。squeeze(a)就是將a中所有為1的維度刪掉。不為1的維度沒有影響。a.squeeze(N) 就是去掉a中指定的維數為一的維度。還有一種形式就是b=torch.squeeze(a,N) a中去掉指定的定的維數為一的維度。
torch.unsqueeze()
torch.unsqueeze(input, dim, out=None)
作用:擴充套件維度
返回一個新的張量,對輸入的既定位置插入維度 1
注意: 返回張量與輸入張量共享記憶體,所以改變其中一個的內容會改變另一個。
如果dim為負,則將會被轉化dim+input.dim()+1
- 引數:
tensor (Tensor)
– 輸入張量dim (int)
– 插入維度的索引out (Tensor, optional)
– 結果張量
import torch
x = torch.Tensor([1, 2, 3, 4]) # torch.Tensor是預設的tensor型別(torch.FlaotTensor)的簡稱。
print('-' * 50)
print(x) # tensor([1., 2., 3., 4.])
print(x.size()) # torch.Size([4])
print(x.dim()) # 1
print(x.numpy()) # [1. 2. 3. 4.]
print('-' * 50)
print(torch.unsqueeze(x, 0)) # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, 0).size()) # torch.Size([1, 4])
print(torch.unsqueeze(x, 0).dim()) # 2
print(torch.unsqueeze(x, 0).numpy()) # [[1. 2. 3. 4.]]
print('-' * 50)
print(torch.unsqueeze(x, 1))
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(torch.unsqueeze(x, 1).size()) # torch.Size([4, 1])
print(torch.unsqueeze(x, 1).dim()) # 2
print('-' * 50)
print(torch.unsqueeze(x, -1))
# tensor([[1.],
# [2.],
# [3.],
# [4.]])
print(torch.unsqueeze(x, -1).size()) # torch.Size([4, 1])
print(torch.unsqueeze(x, -1).dim()) # 2
print('-' * 50)
print(torch.unsqueeze(x, -2)) # tensor([[1., 2., 3., 4.]])
print(torch.unsqueeze(x, -2).size()) # torch.Size([1, 4])
print(torch.unsqueeze(x, -2).dim()) # 2
==總結:==給指定位置加上維數為一的維度,比如原本有個三行的資料(3),在0的位置加了一維就變成一行三列(1,3)。a.unsqueeze(N) 就是在a中指定位置N加上一個維數為1的維度。
參考:https://zhuanlan.zhihu.com/p/86763381