1. 程式人生 > 其它 >`torch.gather`理解

`torch.gather`理解

official link

函式定義

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

沿著dim指定的軸聚集tensor的值。返回的是原資料的複製,修改返回值不會修改原tensor。
引數:

  • input: 原tensor
  • dim: 待索引的軸
  • index: 待聚集元素的索引

直觀一點說就是:獲取tensor中指定dim和指定index的資料,index可以不連續,dim只能指定為單個軸。

圖解

可參照知乎 - 圖解PyTorch中的torch.gather函式

程式碼實戰

問題:torch.gather

和中括號索引有啥區別嗎?換言之,有啥功能torch.gather可以實現而中括號索引做不到的嗎?
看幾個例子(注意下文語言描述中index也是從0開始)

  • 功能1:一維陣列順序索引,x是長度為8的tensor,取出其中2,3,4,5元素
import torch

x = torch.rand(8)

# 中括號索引
x1 = x[2:6]

# torch.gather
idx = torch.arange(2, 6)
x2 = x.gather(dim=0, index=idx)

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([0.2986, 0.9610, 0.5088, 0.5334])
  • 功能2:一維陣列亂序索引,x是長度為8的tensor,取出其中第3,2,1,4元素
import torch

x = torch.rand(8)

# 中括號索引
x1 = x[[3, 2, 1, 4]]

# torch.gather
idx = torch.tensor([3, 2, 1, 4])
x2 = x.gather(dim=0, index=idx)

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([0.2344, 0.3249, 0.6847, 0.0074])
  • 功能3:二維陣列,取出左上角矩陣塊 (32, 32) -> (16, 16)
import torch

img = torch.rand(32, 32)

# 中括號索引
x1 = img[:16, :16]

# torch.gather
# torch.gather cannot do that, because it only gather from single axis

  • 功能4:二維陣列,取出其中一維陣列的前5個元素
import torch

x = torch.rand(10, 8)  # 理解為10個長度為8的一維陣列

# 中括號索引
x1 = x[:, :5]

# torch.gather
idx = torch.arange(5)
x2 = x.gather(dim=1, index=idx.repeat(10, 1))

print(x1.equal(x2))
print(x2)

# >> True
# >> tensor([[0.6874, 0.0678, 0.9632, 0.1192, 0.6583],
        [0.4384, 0.2263, 0.7262, 0.1914, 0.5774],
        [0.1143, 0.4723, 0.2176, 0.6535, 0.3592],
        [0.6786, 0.9794, 0.3704, 0.2499, 0.3386],
        [0.2688, 0.0812, 0.1744, 0.7484, 0.4401],
        [0.1044, 0.1304, 0.1224, 0.7055, 0.8579],
        [0.5830, 0.8599, 0.2381, 0.0195, 0.0563],
        [0.9367, 0.5019, 0.7067, 0.4395, 0.5474],
        [0.6782, 0.0398, 0.1375, 0.7691, 0.2615],
        [0.6938, 0.3334, 0.8047, 0.6111, 0.0039]])
  • 功能5:二維陣列shape=(3, 6),取出其中每個一維陣列2個元素,陣列1對應位置0、1,陣列2對應位置2、3,陣列3對應位置4、5
import torch

x = torch.rand(3, 6)  # 理解為10個長度為8的一維陣列

# 中括號索引???

# torch.gather

idx = torch.tensor([[0, 1], [2, 3], [4, 5]])
# idx = torch.arange(6).reshape(2, 3)

x2 = x.gather(dim=1, index=idx)

print(x)
print(idx)
print(x2)

# >> tensor([[0.9728, 0.8356, 0.0183, 0.7821, 0.8426, 0.1422],
# >>         [0.3964, 0.4667, 0.3980, 0.3452, 0.3055, 0.8527],
# >>         [0.2162, 0.5601, 0.4261, 0.1134, 0.0281, 0.4682]])
# >> tensor([[0, 1],
# >>         [2, 3],
# >>         [4, 5]])
# >> tensor([[0.9728, 0.8356],
# >>         [0.3980, 0.3452],
# >>         [0.0281, 0.4682]])

總結

torch.gather僅能在一個維度上索引,對於一批陣列,可以檢索不同位置的元素;中括號索引可以在多個維度上操作,但對於一批陣列,只能獲取相同位置的元素。