`torch.gather`理解
阿新 • • 發佈:2022-03-23
函式定義
torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor
沿著dim指定的軸聚集tensor的值。返回的是原資料的複製,修改返回值不會修改原tensor。
引數:
- input: 原tensor
- dim: 待索引的軸
- index: 待聚集元素的索引
直觀一點說就是:獲取tensor中指定dim和指定index的資料,index可以不連續,dim只能指定為單個軸。
圖解
可參照知乎 - 圖解PyTorch中的torch.gather函式
程式碼實戰
問題:torch.gather
torch.gather
可以實現而中括號索引做不到的嗎?看幾個例子(注意下文語言描述中index也是從0開始)
- 功能1:一維陣列順序索引,x是長度為8的tensor,取出其中2,3,4,5元素
import torch x = torch.rand(8) # 中括號索引 x1 = x[2:6] # torch.gather idx = torch.arange(2, 6) x2 = x.gather(dim=0, index=idx) print(x1.equal(x2)) print(x2) # >> True # >> tensor([0.2986, 0.9610, 0.5088, 0.5334])
- 功能2:一維陣列亂序索引,x是長度為8的tensor,取出其中第3,2,1,4元素
import torch
x = torch.rand(8)
# 中括號索引
x1 = x[[3, 2, 1, 4]]
# torch.gather
idx = torch.tensor([3, 2, 1, 4])
x2 = x.gather(dim=0, index=idx)
print(x1.equal(x2))
print(x2)
# >> True
# >> tensor([0.2344, 0.3249, 0.6847, 0.0074])
- 功能3:二維陣列,取出左上角矩陣塊 (32, 32) -> (16, 16)
import torch
img = torch.rand(32, 32)
# 中括號索引
x1 = img[:16, :16]
# torch.gather
# torch.gather cannot do that, because it only gather from single axis
- 功能4:二維陣列,取出其中一維陣列的前5個元素
import torch
x = torch.rand(10, 8) # 理解為10個長度為8的一維陣列
# 中括號索引
x1 = x[:, :5]
# torch.gather
idx = torch.arange(5)
x2 = x.gather(dim=1, index=idx.repeat(10, 1))
print(x1.equal(x2))
print(x2)
# >> True
# >> tensor([[0.6874, 0.0678, 0.9632, 0.1192, 0.6583],
[0.4384, 0.2263, 0.7262, 0.1914, 0.5774],
[0.1143, 0.4723, 0.2176, 0.6535, 0.3592],
[0.6786, 0.9794, 0.3704, 0.2499, 0.3386],
[0.2688, 0.0812, 0.1744, 0.7484, 0.4401],
[0.1044, 0.1304, 0.1224, 0.7055, 0.8579],
[0.5830, 0.8599, 0.2381, 0.0195, 0.0563],
[0.9367, 0.5019, 0.7067, 0.4395, 0.5474],
[0.6782, 0.0398, 0.1375, 0.7691, 0.2615],
[0.6938, 0.3334, 0.8047, 0.6111, 0.0039]])
- 功能5:二維陣列shape=(3, 6),取出其中每個一維陣列2個元素,陣列1對應位置0、1,陣列2對應位置2、3,陣列3對應位置4、5
import torch
x = torch.rand(3, 6) # 理解為10個長度為8的一維陣列
# 中括號索引???
# torch.gather
idx = torch.tensor([[0, 1], [2, 3], [4, 5]])
# idx = torch.arange(6).reshape(2, 3)
x2 = x.gather(dim=1, index=idx)
print(x)
print(idx)
print(x2)
# >> tensor([[0.9728, 0.8356, 0.0183, 0.7821, 0.8426, 0.1422],
# >> [0.3964, 0.4667, 0.3980, 0.3452, 0.3055, 0.8527],
# >> [0.2162, 0.5601, 0.4261, 0.1134, 0.0281, 0.4682]])
# >> tensor([[0, 1],
# >> [2, 3],
# >> [4, 5]])
# >> tensor([[0.9728, 0.8356],
# >> [0.3980, 0.3452],
# >> [0.0281, 0.4682]])
總結
torch.gather
僅能在一個維度上索引,對於一批陣列,可以檢索不同位置的元素;中括號索引可以在多個維度上操作,但對於一批陣列,只能獲取相同位置的元素。