PyTorch筆記之scatter()函式的使用
阿新 • • 發佈:2020-02-13
scatter() 和 scatter_() 的作用是一樣的,只不過 scatter() 不會直接修改原來的 Tensor,而 scatter_() 會
PyTorch 中,一般函式加下劃線代表直接在原來的 Tensor 上修改
scatter(dim,index,src) 的引數有 3 個
- dim:沿著哪個維度進行索引
- index:用來 scatter 的元素索引
- src:用來 scatter 的源元素,可以是一個標量或一個張量
這個 scatter可以理解成放置元素或者修改元素
簡單說就是通過一個張量 src 來修改另一個張量,哪個元素需要修改、用 src 中的哪個元素來修改由 dim 和 index 決定
官方文件給出了 3維張量 的具體操作說明,如下所示
self[index[i][j][k]][j][k] = src[i][j][k] # if dim == 0 self[i][index[i][j][k]][k] = src[i][j][k] # if dim == 1 self[i][j][index[i][j][k]] = src[i][j][k] # if dim == 2
exmaple:
x = torch.rand(2,5) #tensor([[0.1940,0.3340,0.8184,0.4269,0.5945],# [0.2078,0.5978,0.0074,0.0943,0.0266]]) torch.zeros(3,5).scatter_(0,torch.tensor([[0,1,2,0],[2,2]]),x) #tensor([[0.1940,# [0.0000,0.0000,0.0000],0.0266]])
具體地說,我們的 index 是torch.tensor([[0,2]]),一個二維張量,下面用圖簡單說明
我們是 2維 張量,一開始進行 $self[index[0][0]][0]$,其中 $index[0][0]$ 的值是0,所以執行 $self[0][0] = x[0][0] = 0.1940$
$self[index[i][j]][j] = src[i][j] $
再比如$self[index[1][0]][0]$,其中 $index[1][0]$ 的值是2,所以執行 $self[2][0] = x[1][0] = 0.2078$
src 除了可以是張量外,也可以是一個標量
example:
torch.zeros(3,7) #tensor([[7.,7.,7.],# [0.,0.,0.],# [7.,7.]]
scatter()一般可以用來對標籤進行 one-hot 編碼,這就是一個典型的用標量來修改張量的一個例子
example:
class_num = 10 batch_size = 4 label = torch.LongTensor(batch_size,1).random_() % class_num #tensor([[6],# [0],# [3],# [2]]) torch.zeros(batch_size,class_num).scatter_(1,label,1) #tensor([[0.,1.,# [1.,0.]])
以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支援我們。