1. 程式人生 > 程式設計 >PyTorch筆記之scatter()函式的使用

PyTorch筆記之scatter()函式的使用

scatter() 和 scatter_() 的作用是一樣的,只不過 scatter() 不會直接修改原來的 Tensor,而 scatter_() 會

PyTorch 中,一般函式加下劃線代表直接在原來的 Tensor 上修改

scatter(dim,index,src) 的引數有 3 個

  • dim:沿著哪個維度進行索引
  • index:用來 scatter 的元素索引
  • src:用來 scatter 的源元素,可以是一個標量或一個張量

這個 scatter可以理解成放置元素或者修改元素

簡單說就是通過一個張量 src 來修改另一個張量,哪個元素需要修改、用 src 中的哪個元素來修改由 dim 和 index 決定

官方文件給出了 3維張量 的具體操作說明,如下所示

self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0
self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1
self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2

exmaple:

x = torch.rand(2,5)

#tensor([[0.1940,0.3340,0.8184,0.4269,0.5945],#    [0.2078,0.5978,0.0074,0.0943,0.0266]])

torch.zeros(3,5).scatter_(0,torch.tensor([[0,1,2,0],[2,2]]),x)

#tensor([[0.1940,#    [0.0000,0.0000,0.0000],0.0266]])

具體地說,我們的 index 是torch.tensor([[0,2]]),一個二維張量,下面用圖簡單說明

我們是 2維 張量,一開始進行 $self[index[0][0]][0]$,其中 $index[0][0]$ 的值是0,所以執行 $self[0][0] = x[0][0] = 0.1940$

$self[index[i][j]][j] = src[i][j] $

PyTorch筆記之scatter()函式的使用

再比如$self[index[1][0]][0]$,其中 $index[1][0]$ 的值是2,所以執行 $self[2][0] = x[1][0] = 0.2078$

PyTorch筆記之scatter()函式的使用

src 除了可以是張量外,也可以是一個標量

example:

torch.zeros(3,7)

#tensor([[7.,7.,7.],#    [0.,0.,0.],#    [7.,7.]]

scatter()一般可以用來對標籤進行 one-hot 編碼,這就是一個典型的用標量來修改張量的一個例子

example:

class_num = 10
batch_size = 4
label = torch.LongTensor(batch_size,1).random_() % class_num
#tensor([[6],#    [0],#    [3],#    [2]])
torch.zeros(batch_size,class_num).scatter_(1,label,1)
#tensor([[0.,1.,#    [1.,0.]])

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援我們。