1. 程式人生 > 其它 >理解torch.scatter_()

理解torch.scatter_()

技術標籤:pytorchpytorch深度學習

理解torch.scatter_()

官方文件

scatter_(dim, index, src): 將src中所有的值分散到self 中,填法是按照index中所指示的索引來填入。

dim用來指定index進行對映的維度,其他維度則保持不變。

Note: src可以是一個scalar。在這種情況下,該函式的操作是根據index來散佈單個值。

當dim=0

dim=0,意味著在src按照index行索引的指示來進行散射,換言之,srcj列按照index

j列中的值散射到selfj列中。(表述還是很繞,看例子吧)

以下是官方的例子:

>>
> x = torch.rand(2, 5) >>> x tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004], [ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]) >>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x) tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004], [ 0.0000, 0.2908, 0.0000, 0.4152
, 0.0000], [ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])

因為dim=0,所以是列對映到列,散射操作可以按列依次進行。

第一列:
在這裡插入圖片描述
第二列:
在這裡插入圖片描述
直到最後一列:
在這裡插入圖片描述

當dim = 1

dim=1,意味著在src按照index列索引的指示來進行散射,換言之,srci行按照index

i行中的值散射到selfi列中。

>>> src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
>>> input_tensor = torch.zeros(3, 5)
>>
> index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]]) >>> dim = 1 >>> input_tensor.scatter_(dim, index_tensor, src) tensor([[ 2., 4., 3., 1., 5.], [ 7., 10., 6., 9., 0.], [ 0., 0., 0., 0., 0.]])

散射操作前:
在這裡插入圖片描述
更新第一行:
在這裡插入圖片描述
更新第二行, 可以看到index中出現重複的對映索引值1,因此後一個會把前一個覆蓋:

8和10都是對映到col1,可以看到10把8給覆蓋了。

當src是scalar

>>> input_tensor = torch.from_numpy(np.arange(1, 16)).float().view(3, 5) # dim is 2
>>> # unsqueeze to have dim = 2
>>> index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1) 
>>> src = 0
>>> dim = 1
>>> input_tensor.scatter_(dim, index_tensor, src)
tensor([[ 1.,  2.,  3.,  4.,  0.],
        [ 0.,  7.,  8.,  9., 10.],
        [11.,  0., 13., 14., 15.]])

Note:

  • index的維度要和輸入張量的維度保持一致。同時index要在相同維度上的尺度不能大於輸入張量。

  • src是標量時,我們實際上使用的是廣播版本,其形狀與index張量相同。
    在這裡插入圖片描述

程式碼實操

該函式最常用的場景是把標量的標籤轉換為one-hot編碼

batch_size = 4
class_num = 5
labels = torch.tensor([4, 0, 1, 2]).unsqueeze(1)
one_hot = torch.zeros(batch_size, class_num)
dim=1; index_tensor = labels; src=1
one_hot.scatter_(dim, index_tensor, src)
print(one_hot)
> tensor([[0., 0., 0., 0., 1.],
        [1., 0., 0., 0., 0.],
        [0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0.]])

References:

  1. Understand torch.scatter_()

  2. https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_