【PyTorch】tensor.scatter
阿新 • Published: 2020-08-22
Parameters:
- dim (int) – the axis along which to index
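- index (LongTensor) – the indices of elements to scatter; either empty or with the same number of dimensions as src (with index.size(d) <= src.size(d) for every dimension d). When empty, the operation returns self unchanged.
- src (Tensor) – the source element(s) to scatter, used when value is not specified
- value (float) – the source element(s) to scatter, used when src is not specified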
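The parameter list does not spell out the indexing rule itself. For a 2-D tensor, scatter_ visits every position (i, j) of index and writes src[i][j] into self. One coordinate of the destination is replaced by index[i][j]. For dim = 0 the rule is self[index[i][j]][j] = src[i][j], and for dim = 1 it is self[i][index[i][j]] = src[i][j].

Below is a minimal plain-Python sketch of that rule on nested lists. scatter_2d and zeros are hypothetical helpers for illustration, not PyTorch code. The sketch also skips the shape and bounds checks PyTorch performs.

```python
def zeros(rows, cols):
    """Hypothetical helper: a rows x cols nested list of 0.0 (stands in for torch.zeros)."""
    return [[0.0] * cols for _ in range(rows)]


def scatter_2d(out, dim, index, src):
    """Hypothetical sketch of the 2-D scatter_ rule on nested lists (not PyTorch's code).

    For every position (i, j) of index:
      dim == 0: out[index[i][j]][j] = src[i][j]
      dim == 1: out[i][index[i][j]] = src[i][j]
    A scalar src is written as-is at every position, like the `value` argument.
    `out` is modified in place and returned, mirroring scatter_.
    """
    if dim not in (0, 1):
        raise ValueError("this sketch only handles dim 0 or 1")
    scalar = isinstance(src, (int, float))
    for i, row in enumerate(index):
        for j, k in enumerate(row):
            v = src if scalar else src[i][j]
            if dim == 0:
                out[k][j] = v  # row comes from index, column stays j
            else:
                out[i][k] = v  # row stays i, column comes from index
    return out


# Usage: reproduce the official examples with the numbers shown in this post.
x = [[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
     [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
idx = [[0, 1, 2, 0, 0],
       [2, 0, 0, 1, 2]]
print(scatter_2d(zeros(3, 5), 0, idx, x))
print(scatter_2d(zeros(2, 4), 1, [[2], [3]], 1.23))
```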
Examples from the official documentation:
When the third argument is a tensor:
>>> import torch
>>> x = torch.rand(2, 5)
>>> x
tensor([[ 0.3992,  0.2908,  0.9044,  0.4850,  0.6004],
        [ 0.5735,  0.9006,  0.6797,  0.4152,  0.1732]])
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), x)
tensor([[ 0.3992,  0.9006,  0.6797,  0.4850,  0.6004],
        [ 0.0000,  0.2908,  0.0000,  0.4152,  0.0000],
        [ 0.5735,  0.0000,  0.9044,  0.0000,  0.1732]])
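The plain-Python trace below makes the dim = 0 rule concrete. It prints where each element of x from the example above ends up: the row is taken from index and the column is kept. It only reuses the numbers shown above and does not call PyTorch.

```python
# Trace of the dim=0 tensor example: x[i][j] is written to result[index[i][j]][j].
x = [[0.3992, 0.2908, 0.9044, 0.4850, 0.6004],
     [0.5735, 0.9006, 0.6797, 0.4152, 0.1732]]
index = [[0, 1, 2, 0, 0],
         [2, 0, 0, 1, 2]]

moves = []  # entries of the form ((i, j), (dest_row, dest_col), value)
for i in range(len(index)):
    for j in range(len(index[i])):
        dest = (index[i][j], j)
        moves.append(((i, j), dest, x[i][j]))
        print(f"x[{i}][{j}] = {x[i][j]:.4f} -> result[{dest[0]}][{dest[1]}]")

# Cells of the 3 x 5 result that no element is written to keep their initial 0.
written = {dest for _, dest, _ in moves}
untouched = {(r, c) for r in range(3) for c in range(5)} - written
print(sorted(untouched))
```

All ten destinations are distinct, so exactly five of the fifteen cells stay 0. These match the zeros in the output above.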
When the third argument is a scalar:
>>> z = torch.zeros(2, 4).scatter_(1, torch.tensor([[2], [3]]), 1.23)
>>> z
tensor([[ 0.0000, 0.0000, 1.2300, 0.0000],
        [ 0.0000,  0.0000,  0.0000,  1.2300]])
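Conceptually, the scalar form behaves as if src were a tensor of the same shape as index, filled with that value. Below is a plain-Python sketch of the example above under that reading. The list z is an assumption standing in for torch.zeros(2, 4), and no PyTorch call is made.

```python
# dim=1 scalar example: row i stays i, the column comes from index[i][j].
index = [[2], [3]]
value = 1.23
src = [[value] * len(row) for row in index]  # same 2 x 1 shape as index

z = [[0.0] * 4 for _ in range(2)]  # stands in for torch.zeros(2, 4)
for i, row in enumerate(index):
    for j, k in enumerate(row):
        z[i][k] = src[i][j]
print(z)
```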
Another example:
dim = 0
>>> torch.zeros(3, 5).scatter_(0, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
tensor([[7., 7., 7., 7., 7.],
        [0., 7., 0., 7., 0.],
        [7., 0., 7., 0., 7.]])
dim = 1
>>> torch.zeros(3, 5).scatter_(1, torch.tensor([[0, 1, 2, 0, 0], [2, 0, 0, 1, 2]]), 7)
tensor([[7., 7., 7., 0., 0.],
        [7., 7., 7., 0., 0.],
        [0., 0., 0., 0., 0.]])
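In the dim = 1 case, row i of index writes into row i of the result. Because index has only two rows, row 2 of the 3 x 5 result is never touched. Several positions are also written more than once, for example column 0 three times in row 0. With the scalar 7 this is harmless because every write stores the same value. With a src tensor, recent PyTorch documentation warns that non-unique indices make the result nondeterministic.

The plain-Python sketch below counts the writes. It is an illustration only and does not call PyTorch.

```python
from collections import Counter

# dim=1 example: each (i, j) of index writes to result[i][index[i][j]].
index = [[0, 1, 2, 0, 0],
         [2, 0, 0, 1, 2]]
writes = Counter((i, k) for i, row in enumerate(index) for k in row)

# Every written cell holds 7; row 2 never appears in `writes`, so it stays 0.
result = [[7.0 if (r, c) in writes else 0.0 for c in range(5)] for r in range(3)]
duplicates = {cell: n for cell, n in writes.items() if n > 1}
print(result)
print(duplicates)
```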