torch.argmax和argmin返回值
在進行深度學習張量計算時,經常要獲取張量在某個維度的最大值和最小值,以及這些值的位置。如果只需要知道位置,則torch.argmax和torch.argmin函式便可以實現。
Torch.argmax(input, dim=None, keepdim=False
):返回指定維度最大值的序號。
有時候返回的值比較難理解,所以這裡直接放example以幫助理解:
1 import torch 2 3 t = torch.tensor([[1,2],[3,4],[2,8]]) 4 5 print(torch.argmax(t,0)) 6 7 8 g = torch.tensor([[[1,2,3],[2,3,4],[5,6,7]], [[3,4,5],[7,6,5],[5,4,3]], [[8,9,0],9 [2,8,4],[7,5,3]]]) 10 print(g) 11 print(torch.argmax(g,0))
先從簡單的2維張量來看,t 是一個2維張量,大小為(3,2)。t 為,此時我們使dim=0,意思使求第0維的(即(3,2)中的3行)中的最大值的序號,所以固定行,直接看列,第一列中3最大,故得到值1,第2列中8最大,故得到值2。最終的結果為 tensor([1,2])
再來看一個3維張量g ,tensor([[[1, 2, 3],
[2, 3, 4],
[5, 6, 7]],
[[3, 4, 5],
[7, 6, 5],
[5, 4, 3]],
[[8, 9, 0],
[2, 8, 4],
[7, 5, 3]]]),其大小為(3,3,3) 其中我們希望在dim=0的維度中求最大值的序號,則固定第一個維度,第一個維度為channel,則每個channel中對應位置進行比較。
比如每個channel中的(0,0)比較,1<3<8,所以得到的值為2;(0,1)比較,2<4<9,依然得到2,....以此類推。最終得到結果tensor([[2, 2, 1],[1, 2, 1],[2, 0, 0]])。