1. 程式人生 > 其它 >torch.max()函式

torch.max()函式

一、_, predicted = torch.max(outputs.data, 1)

那麼,這裡的 下劃線_ 表示什麼意思?

首先,torch.max()這個函式返回的是兩個值,第一個值是具體的value(我們用下劃線_表示),第二個值是value所在的index(也就是predicted)。

那麼,這個 下劃線_ 表示的就是具體的value,也就是輸出的最大值。那麼為什麼用 下劃線_,可不可以用其他的變數名稱來代替,比如x?答案自然是可以的。

那麼為什麼這裡選擇用這麼特殊的下劃線?有沒有特殊含義?這是因為我們不關心最大值是什麼,而關心最大值對應的index是什麼,所以選用下劃線代表不需要用到的變數。比如在影象分類任務中,值所對應的index就對應著相應的類別class,當我們只關心網路預測的類別是什麼,而不關心該類別的預測概率是多少時,就選擇使用下劃線_。

二、這裡的數字1表示什麼意思?

數字1其實可以寫為dim=1,這裡簡寫為1,python也可以自動識別,dim=1表示輸出所在行的最大值,若改寫成dim=0則輸出所在列的最大值。比如說測試集有10個數據,那麼訓練好的網路將會預測這10個數據,得到一個10×2的矩陣(假設是二分類問題,二分類只輸出兩個類別,所以是兩列),比如說預測結果是下面這個矩陣。這裡的數字就是,網路預測為對應類別的概率,而行代表樣本、列代表類別,所以這裡應該用dim=1,因為你需要輸出的每個樣本的預測類別。