argmax對softmax的輸出返回概率最大的類別
阿新 • • 發佈:2020-12-18
假設資料如下,每行是softmax輸出得到的概率,我需要找到最大的概率返回類別,可以使用argmax函式
(1)注意使用argmax函式時,需要將資料轉換為tensor型別,否則報錯 argmax(): argument 'input' (position 1) must be Tensor, not numpy.ndarray (2)torch.argmax函式需要傳遞dim引數,dim=1就是在行上求 index = torch.argmax(data_pre, dim=1)
import numpy as np import pandas as pd import torch data_pre = np.loadtxt('./pred.txt') data_pre = torch.tensor(data_pre) index = torch.argmax(data_pre, dim=1) index = np.array(index) np.savetxt('./cluster.txt', (index))