1. 程式人生 > 實用技巧 >torch.argmax和argmin返回值

torch.argmax和argmin返回值

  在進行深度學習張量計算時,經常要獲取張量在某個維度的最大值和最小值,以及這些值的位置。如果只需要知道位置,則torch.argmax和torch.argmin函式便可以實現。

Torch.argmax(input, dim=None, keepdim=False):返回指定維度最大值的序號。

  有時候返回的值比較難理解,所以這裡直接放example以幫助理解:

 1 import torch
 2 
 3 t = torch.tensor([[1,2],[3,4],[2,8]])
 4 
 5 print(torch.argmax(t,0))
 6 
 7 
 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],        
9 [2,8,4],[7,5,3]]]) 10 print(g) 11 print(torch.argmax(g,0))

先從簡單的2維張量來看,t 是一個2維張量,大小為(3,2)。t 為,此時我們使dim=0,意思使求第0維的(即(3,2)中的3行)中的最大值的序號,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最終的結果為 tensor([1,2])


再來看一個3維張量g ,tensor([[[1, 2, 3],

              [2, 3, 4],
              [5, 6, 7]],

              [[3, 4, 5],
              [7, 6, 5],
              [5, 4, 3]],

              [[8, 9, 0],
              [2, 8, 4],
              [7, 5, 3]]]),其大小為(3,3,3) 其中我們希望在dim=0的維度中求最大值的序號,則固定第一個維度,第一個維度為channel,則每個channel中對應位置進行比較。

比如每個channel中的(0,0)比較,1<3<8,所以得到的值為2;(0,1)比較,2<4<9,依然得到2,....以此類推。最終得到結果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。