1. 程式人生 > >pytorch訓練ImageNet筆記(一)--accuracy

pytorch訓練ImageNet筆記(一)--accuracy

一:準確度的計算

# 計算準確度
def accuracy(output, target, topk=(1,5)):
    """Computes the [email protected] for the specified values of k
    prec1, prec5 = accuracy(output.data, target, topk=(1, 5))
    """
    maxk = max(topk)
    # size函式:總元素的個數
    batch_size = target.size(0) 

    # topk函式選取output前k大個數
    _, pred = output.topk(maxk, 1, True, True)
    ##########不瞭解t()k
    pred = pred.t()
    correct = pred.eq(target.view(1, -1).expand_as(pred))   
    res = []
    for k in topk:
        correct_k = correct[:k].view(-1).float().sum(0, keepdim=True)
        res.append(correct_k.mul_(100.0 / batch_size))

用到的pytorch函式:max、topk、eq

以top5,batch_size = 64為例:

1.topk()的使用:獲取output中最大的5個數,如對第一張圖片的output為[0.03,0.05,0.04,0.1.......(共1000個)],經過topk()後得到[0.2,0.19,0.18,0.17,0.16](不清楚是不是降序),共64張圖片,可理解為一個5 X 64 的矩陣(其實是Tensor),每一列的數即為該張圖片的預測值。

2.target為64個值,代表每張圖片所屬的類別,經過target.view()後變成[1,2,3,4,5,6,7.....],再經過expand_as(pred),該函式代表將target.view()擴充套件到跟pred相同的維度,可理解為5 X 64的矩陣,每一列的值都相同。

3.經過torch.eq()後,correct為一個5X64的矩陣,每一列僅有一個值為1,如第一列[0,0,1,0,0],代表top1錯誤,top3正確,每一行代表top-k的值,在for迴圈中,可自由選擇k的值,當K的值為4時,取correct前4行,通過.float().sum()操作,計算top4中1的總個數,除以batch_size,則能得到top4的準確率。