1. 程式人生 > >pytorch 多分類問題,計算百分比

pytorch 多分類問題,計算百分比

二分類或分類問題,網路輸出為二維矩陣:批次x幾分類,最大的為當前分類,標籤為one-hot型的二維矩陣:批次x幾分類

計算百分比有numpy和pytorch兩種實現方案實現,都是根據索引計算百分比,以下為具體二分類實現過程。

pytorch

out = torch.Tensor([[0,3],
                    [2,3],
                    [1,0],
                    [3,4]])
cond = torch.Tensor([[1,0],
                     [0,1],
                     [1,0],
                     [1,0]])

persent = torch.mean(torch.eq(torch.argmax(out, dim=1), torch.argmax(cond, dim=1)).double())
print(persent)

numpy

out = [[0, 3],
       [2, 3],
       [1, 0],
       [3, 4]]
cond = [[1, 0],
        [0, 1],
        [1, 0],
        [1, 0]] 
a = np.argmax(out,axis=1)
b = np.argmax(cond, axis=1)
persent = np.mean(np.equal(a, b) + 0)
# persent = np.mean(a==b + 0)
print(persent)