pytorch 多分類問題,計算百分比
阿新 • • 發佈:2018-11-14
二分類或分類問題,網路輸出為二維矩陣:批次x幾分類,最大的為當前分類,標籤為one-hot型的二維矩陣:批次x幾分類
計算百分比有numpy和pytorch兩種實現方案實現,都是根據索引計算百分比,以下為具體二分類實現過程。
pytorch
out = torch.Tensor([[0,3], [2,3], [1,0], [3,4]]) cond = torch.Tensor([[1,0], [0,1], [1,0], [1,0]]) persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double()) print(persent)
numpy
out = [[0, 3],
[2, 3],
[1, 0],
[3, 4]]
cond = [[1, 0],
[0, 1],
[1, 0],
[1, 0]]
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)