pytorch訓練ImageNet筆記(一)--accuracy
阿新 • • 發佈:2018-12-22
一:準確度的計算
# 計算準確度 def accuracy(output, target, topk=(1,5)): """Computes the [email protected] for the specified values of k prec1, prec5 = accuracy(output.data, target, topk=(1, 5)) """ maxk = max(topk) # size函式:總元素的個數 batch_size = target.size(0) # topk函式選取output前k大個數 _, pred = output.topk(maxk, 1, True, True) ##########不瞭解t()k pred = pred.t() correct = pred.eq(target.view(1, -1).expand_as(pred)) res = [] for k in topk: correct_k = correct[:k].view(-1).float().sum(0, keepdim=True) res.append(correct_k.mul_(100.0 / batch_size))
用到的pytorch函式:max、topk、eq
以top5,batch_size = 64為例:
1.topk()的使用:獲取output中最大的5個數,如對第一張圖片的output為[0.03,0.05,0.04,0.1.......(共1000個)],經過topk()後得到[0.2,0.19,0.18,0.17,0.16](不清楚是不是降序),共64張圖片,可理解為一個5 X 64 的矩陣(其實是Tensor),每一列的數即為該張圖片的預測值。
2.target為64個值,代表每張圖片所屬的類別,經過target.view()後變成[1,2,3,4,5,6,7.....],再經過expand_as(pred),該函式代表將target.view()擴充套件到跟pred相同的維度,可理解為5 X 64的矩陣,每一列的值都相同。
3.經過torch.eq()後,correct為一個5X64的矩陣,每一列僅有一個值為1,如第一列[0,0,1,0,0],代表top1錯誤,top3正確,每一行代表top-k的值,在for迴圈中,可自由選擇k的值,當K的值為4時,取correct前4行,通過.float().sum()操作,計算top4中1的總個數,除以batch_size,則能得到top4的準確率。