理解torch.scatter_()
阿新 • • 發佈:2020-12-23
理解torch.scatter_()
官方文件
scatter_
(dim, index, src): 將src中所有的值分散到self
中,填法是按照index
中所指示的索引來填入。
dim
用來指定index
進行對映的維度,其他維度則保持不變。
Note: src
可以是一個scalar。在這種情況下,該函式的操作是根據index
來散佈單個值。
當dim=0
dim=0,意味著在src
按照index
行索引的指示來進行散射,換言之,src
的j
列按照index
的j
列中的值散射到self
的j
列中。(表述還是很繞,看例子吧)
以下是官方的例子:
>> > x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
[ 0.5735, 0.9006, 0.6797, 0.4152, 0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992, 0.9006, 0.6797, 0.4850, 0.6004],
[ 0.0000, 0.2908, 0.0000, 0.4152 , 0.0000],
[ 0.5735, 0.0000, 0.9044, 0.0000, 0.1732]])
因為dim=0,所以是列對映到列,散射操作可以按列依次進行。
第一列:
第二列:
直到最後一列:
當dim = 1
dim=1,意味著在src
按照index
列索引的指示來進行散射,換言之,src
的i
行按照index
的i
行中的值散射到self
的i
列中。
>>> src = torch.from_numpy(np.arange(1, 11)).float().view(2, 5)
>>> input_tensor = torch.zeros(3, 5)
>> > index_tensor = torch.tensor([[3, 0, 2, 1, 4], [2, 0, 1, 3, 1]])
>>> dim = 1
>>> input_tensor.scatter_(dim, index_tensor, src)
tensor([[ 2., 4., 3., 1., 5.],
[ 7., 10., 6., 9., 0.],
[ 0., 0., 0., 0., 0.]])
散射操作前:
更新第一行:
更新第二行, 可以看到index
中出現重複的對映索引值1
,因此後一個會把前一個覆蓋:
8和10都是對映到col1
,可以看到10把8給覆蓋了。
當src是scalar
>>> input_tensor = torch.from_numpy(np.arange(1, 16)).float().view(3, 5) # dim is 2
>>> # unsqueeze to have dim = 2
>>> index_tensor = torch.tensor([4, 0, 1]).unsqueeze(1)
>>> src = 0
>>> dim = 1
>>> input_tensor.scatter_(dim, index_tensor, src)
tensor([[ 1., 2., 3., 4., 0.],
[ 0., 7., 8., 9., 10.],
[11., 0., 13., 14., 15.]])
Note:
-
index
的維度要和輸入張量的維度保持一致。同時index
要在相同維度上的尺度不能大於輸入張量。 -
當
src
是標量時,我們實際上使用的是廣播版本,其形狀與index
張量相同。
程式碼實操
該函式最常用的場景是把標量的標籤轉換為one-hot編碼
batch_size = 4
class_num = 5
labels = torch.tensor([4, 0, 1, 2]).unsqueeze(1)
one_hot = torch.zeros(batch_size, class_num)
dim=1; index_tensor = labels; src=1
one_hot.scatter_(dim, index_tensor, src)
print(one_hot)
> tensor([[0., 0., 0., 0., 1.],
[1., 0., 0., 0., 0.],
[0., 1., 0., 0., 0.],
[0., 0., 1., 0., 0.]])
References:
-
https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_