torch中的幾個函式(min()、max()、prod()、clamp_()、where())
阿新 • • 發佈:2022-03-15
torch.min()、torch.max()、torch.prod()
這兩個函式很好理解,就是求張量中的最小值和最大值以及相乘
1.在這兩個函式中如果沒有指定維度的話,那麼預設是將張量中的所有值進行比較,輸出最大值或者最小值或是所有值相乘。
2.而當指定維度之後,會將對應維度的資料進行比較,同時輸出的有最小值以及這個最小值在對應維度的下標,或是指定維度相乘
3.使用這兩個函式對兩個張量進行比較時,輸出的是張量中每一個值對應位置的最小值,對應位置相乘
1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])2 >>> b = torch.Tensor([[2,1,4,3],[6,5,8,7]]) 3 >>> torch.min(a) 4 tensor(1.) 5 6 >>> torch.min(a,dim=0) 7 torch.return_types.min( 8 values=tensor([1., 2., 3., 4.]), 9 indices=tensor([0, 0, 0, 0])) 10 11 >>> torch.min(a,b) 12 tensor([[1., 1., 3., 3.], 13 [5., 5., 7., 7.]])14 15 >>> torch.min(a[:,2:],b[:,2:]) 16 tensor([[3., 3.], 17 [7., 7.]])
torch.clamp_()
這個函式其實就是對張量進行上下限的限制,超過了指定的上限或是下限之後,該值賦值為確定的界限的值
1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]]) 2 >>> a.clamp_(min = 2.5,max = 6.5) 3 4 tensor([[2.5000, 2.5000, 3.0000, 4.0000], 5 [5.0000, 6.0000, 6.5000, 6.5000]])
torch.where()
函式的定義如下:
- torch.where(condition, x, y):
- condition:判斷條件
- x:若滿足條件,則取x中元素
- y:若不滿足條件,則取y中元素
1 import torch 2 # 條件 3 condition = torch.rand(3, 2) 4 print(condition) 5 # 滿足條件則取x中對應元素 6 x = torch.ones(3, 2) 7 print(x) 8 # 不滿足條件則取y中對應元素 9 y = torch.zeros(3, 2) 10 print(y) 11 # 條件判斷後的結果 12 result = torch.where(condition > 0.5, x, y) 13 print(result) 14 15 16 17 18 tensor([[0.3224, 0.5789], 19 [0.8341, 0.1673], 20 [0.1668, 0.4933]]) 21 tensor([[1., 1.], 22 [1., 1.], 23 [1., 1.]]) 24 tensor([[0., 0.], 25 [0., 0.], 26 [0., 0.]]) 27 tensor([[0., 1.], 28 [1., 0.], 29 [0., 0.]])
可以看到是對張量中的每一個值進行比較,單獨進行條件判斷,輸出張量對應的位置為判斷後對應判擇的輸出