1. 程式人生 > 實用技巧 >Tensor的索引與變形-04

Tensor的索引與變形-04

1. 索引操作與NumPy非常類似, 主要包含下標索引、 表示式索引、 使用torch.where()Tensor.clamp()的選擇性索引。

 1 import torch 
 2 
 3 a = torch.Tensor([[0,1],[2,3]])
 4 print(a, a.size())
 5 >> tensor([[0., 1.],
 6         [2., 3.]]) torch.Size([2, 2])
 7 
 8 # 根據下標進行索引
 9 print(a[1])
10 >> tensor([2., 3.])
11 print(a[0,1])
12 >> tensor(1.) 13 14 # 選擇a中大於0的元素, 返回和a相同大小的Tensor, 符合條件的置1, 否則置0 15 b = a>0 16 print(b) 17 >> tensor([[False, True], 18 [ True, True]]) 19 20 # 選擇符合條件的元素並返回, 等價於torch.masked_select(a, a>0) 21 c= a[a>0] 22 print(c) 23 >> tensor([1., 2., 3.]) 24 25 # 選擇非0元素的座標, 並返回 26
d = torch.nonzero(a) 27 print(d) 28 >> tensor([[0, 1], 29 [1, 0], 30 [1, 1]]) 31 32 # torch.where(condition, x, y), 滿足condition的位置輸出x, 否則輸出y 33 e = torch.where(a>1, torch.full_like(a,1), a) 34 print(e) 35 >> tensor([[0., 1.], 36 [1., 1.]]) 37 38 # 對Tensor元素進行限制可以使用clamp()函式, 示例如下, 限制最小值為1, 最大值為2
39 f = a.clamp(1, 2) 40 print(f) 41 >> tensor([[1., 1.], 42 [2., 2.]])
View Code

2.