1. 程式人生 > 其它 >torch中的幾個函式(min()、max()、prod()、clamp_()、where())

torch中的幾個函式(min()、max()、prod()、clamp_()、where())

torch.min()、torch.max()、torch.prod() 

   這兩個函式很好理解,就是求張量中的最小值和最大值以及相乘

    1.在這兩個函式中如果沒有指定維度的話,那麼預設是將張量中的所有值進行比較,輸出最大值或者最小值或是所有值相乘。

    2.而當指定維度之後,會將對應維度的資料進行比較,同時輸出的有最小值以及這個最小值在對應維度的下標,或是指定維度相乘

    3.使用這兩個函式對兩個張量進行比較時,輸出的是張量中每一個值對應位置的最小值,對應位置相乘

 1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
2 >>> b = torch.Tensor([[2,1,4,3],[6,5,8,7]]) 3 >>> torch.min(a) 4 tensor(1.) 5 6 >>> torch.min(a,dim=0) 7 torch.return_types.min( 8 values=tensor([1., 2., 3., 4.]), 9 indices=tensor([0, 0, 0, 0])) 10 11 >>> torch.min(a,b) 12 tensor([[1., 1., 3., 3.], 13 [5., 5., 7., 7.]])
14 15 >>> torch.min(a[:,2:],b[:,2:]) 16 tensor([[3., 3.], 17 [7., 7.]])

torch.clamp_()

  這個函式其實就是對張量進行上下限的限制,超過了指定的上限或是下限之後,該值賦值為確定的界限的值

 

1 >>> a = torch.Tensor([[1,2,3,4],[5,6,7,8]])
2 >>> a.clamp_(min = 2.5,max = 6.5)
3 
4 tensor([[2.5000, 2.5000, 3.0000, 4.0000],
5         [5.0000, 6.0000, 6.5000, 6.5000]])

 

torch.where()

  函式的定義如下:

  1. torch.where(condition, x, y):
  2.  condition:判斷條件
  3.  x:若滿足條件,則取x中元素
  4.  y:若不滿足條件,則取y中元素
 1 import torch
 2 # 條件
 3 condition = torch.rand(3, 2)
 4 print(condition)
 5 # 滿足條件則取x中對應元素
 6 x = torch.ones(3, 2)
 7 print(x)
 8 # 不滿足條件則取y中對應元素
 9 y = torch.zeros(3, 2)
10 print(y)
11 # 條件判斷後的結果
12 result = torch.where(condition > 0.5, x, y)
13 print(result)
14 
15 
16 
17 
18 tensor([[0.3224, 0.5789],
19         [0.8341, 0.1673],
20         [0.1668, 0.4933]])
21 tensor([[1., 1.],
22         [1., 1.],
23         [1., 1.]])
24 tensor([[0., 0.],
25         [0., 0.],
26         [0., 0.]])
27 tensor([[0., 1.],
28         [1., 0.],
29         [0., 0.]])

  可以看到是對張量中的每一個值進行比較,單獨進行條件判斷,輸出張量對應的位置為判斷後對應判擇的輸出