Object Detection Notes No. 3: The SSD Loss Function
I was fortunate to take part in the Datawhale December team-learning program. These notes are based on 動手學CV-Pytorch, Chapter 3 (Object Detection), Section 3.5, with some additions of my own.
Some terminology from the paper is used here: prior boxes (default boxes / prior bboxes), ground truth boxes, and prediction boxes. The loss function of this framework mainly involves three things: the matching strategy, the design of the loss function, and hard negative mining. The source code packs all three into a single class; below, following my own understanding, I break the code apart and paste it under these three sections.
Matching Strategy
First principle: start from the ground truth boxes and find, for each ground truth box, the prior bbox with the largest jaccard overlap; this guarantees that every ground truth box is matched to at least one prior bbox (jaccard overlap is simply IoU). Conversely, a prior bbox that is not matched to any ground truth can only be matched to the background, i.e. it becomes a negative sample.
Second principle: start from the prior bboxes and try to match each remaining unmatched prior bbox to any ground truth box; as long as the jaccard overlap between the two exceeds a threshold (typically 0.5), that prior bbox is also matched to that ground truth. This means one ground truth may be matched by multiple prior boxes, which is allowed. The reverse is not: a prior bbox can only match one ground truth, so if several ground truths overlap a prior bbox with IoU above the threshold, the prior bbox is matched only to the ground truth with the largest IoU. Note: the second principle must be applied after the first.
Put informally: a ground truth box really does have a "harem of three thousand" and can be matched by many prior boxes, whereas a prior box can only be matched to a single ground truth box.
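Since jaccard overlap is just IoU, here is a minimal sketch of the pairwise overlap computation that a function like find_jaccard_overlap performs. This is my own illustration rather than the tutorial's exact implementation; boxes are assumed to be in (x_min, y_min, x_max, y_max) format.

import torch

def jaccard_overlap(set_1, set_2):
    # set_1: (n1, 4), set_2: (n2, 4); returns a (n1, n2) matrix of pairwise IoU values
    lower = torch.max(set_1[:, None, :2], set_2[None, :, :2])            # (n1, n2, 2) lower corners of intersections
    upper = torch.min(set_1[:, None, 2:], set_2[None, :, 2:])            # (n1, n2, 2) upper corners of intersections
    wh = (upper - lower).clamp(min=0)                                    # (n1, n2, 2) zero where boxes do not overlap
    intersection = wh[:, :, 0] * wh[:, :, 1]                             # (n1, n2)
    areas_1 = (set_1[:, 2] - set_1[:, 0]) * (set_1[:, 3] - set_1[:, 1])  # (n1)
    areas_2 = (set_2[:, 2] - set_2[:, 0]) * (set_2[:, 3] - set_2[:, 1])  # (n2)
    union = areas_1[:, None] + areas_2[None, :] - intersection           # (n1, n2)
    return intersection / union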
The code for the matching step:
# For each image
for i in range(batch_size):
    n_objects = boxes[i].size(0)

    overlap = find_jaccard_overlap(boxes[i], self.priors_xy)  # (n_objects, 441)

    # For each prior, find the object that has the maximum overlap
    overlap_for_each_prior, object_for_each_prior = overlap.max(dim=0)  # (441)

    # We don't want a situation where an object is not represented in our positive (non-background) priors -
    # 1. An object might not be the best object for all priors, and is therefore not in object_for_each_prior.
    # 2. All priors with the object may be assigned as background based on the threshold (0.5).

    # To remedy this -
    # First, find the prior that has the maximum overlap for each object,
    # i.e. the index of the prior that best fits each object
    _, prior_for_each_object = overlap.max(dim=1)  # (n_objects)

    # Then, assign each object to its corresponding maximum-overlap prior. (This fixes 1.)
    # This overwrites whatever object that prior was matched to before, so every object
    # is guaranteed to own at least one positive prior.
    object_for_each_prior[prior_for_each_object] = torch.LongTensor(range(n_objects)).to(device)

    # To ensure these priors qualify, artificially give them an overlap of greater than 0.5. (This fixes 2.)
    overlap_for_each_prior[prior_for_each_object] = 1.
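For completeness, the matching code then turns these matches into per-prior training targets. Roughly, continuing inside the same per-image loop (a sketch, assuming the tutorial's helper functions xy_to_cxcy and cxcy_to_gcxgcy for the offset encoding):

    # Labels for each prior: take the label of its matched object,
    # then set priors whose best overlap is below the threshold to background (class 0)
    label_for_each_prior = labels[i][object_for_each_prior]            # (441)
    label_for_each_prior[overlap_for_each_prior < self.threshold] = 0  # (441)
    true_classes[i] = label_for_each_prior

    # Encode the matched ground truth boxes as offsets (gcx, gcy, gw, gh) relative to the priors
    true_locs[i] = cxcy_to_gcxgcy(xy_to_cxcy(boxes[i][object_for_each_prior]), self.priors_cxcy)  # (441, 4)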
Loss Function Design
Object detection involves both a classification problem and a box regression problem, so the loss is a weighted sum of the two. The subscript conf denotes the confidence (classification) loss and loc the localization loss; N is the number of matched prior boxes.
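The overall objective from the SSD paper is:

L(x, c, l, g) = \frac{1}{N}\Big(L_{conf}(x, c) + \alpha\, L_{loc}(x, l, g)\Big)

If N = 0, i.e. no prior is matched to any ground truth, the loss is set to 0.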
Localization loss: the box information fed into smoothL1( ) is neither (x1, y1, x2, y2) nor (cx, cy, w, h), but the encoded offsets (gcx, gcy, gw, gh) given by the encoding formulas below. In other words, the localization loss regresses the transformation from the prior box to the ground truth box; what enters the loss is box information that has gone through a series of transformations.
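From the SSD paper, the localization loss is a Smooth L1 loss between the predicted offsets l and the encoded ground truth offsets \hat{g}:

L_{loc}(x, l, g) = \sum_{i \in Pos}^{N} \sum_{m \in \{cx, cy, w, h\}} x_{ij}^{k}\, \mathrm{smooth}_{L1}\big(l_i^{m} - \hat{g}_j^{m}\big)

\hat{g}_j^{cx} = \frac{g_j^{cx} - d_i^{cx}}{d_i^{w}}, \qquad
\hat{g}_j^{cy} = \frac{g_j^{cy} - d_i^{cy}}{d_i^{h}}, \qquad
\hat{g}_j^{w} = \log\frac{g_j^{w}}{d_i^{w}}, \qquad
\hat{g}_j^{h} = \log\frac{g_j^{h}}{d_i^{h}}

where d denotes the prior (default) box and g the matched ground truth box.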
Confidence loss: this is the loss for the classification problem. The superscript p denotes the class, i indexes the i-th prior box, j the j-th ground truth box, and x is the matching indicator variable (my own way of putting it).
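In the paper the confidence loss is a softmax cross-entropy over the class confidences c, computed over positives and selected negatives:

L_{conf}(x, c) = -\sum_{i \in Pos}^{N} x_{ij}^{p} \log\big(\hat{c}_i^{p}\big) - \sum_{i \in Neg} \log\big(\hat{c}_i^{0}\big), \qquad
\hat{c}_i^{p} = \frac{\exp(c_i^{p})}{\sum_{p} \exp(c_i^{p})}

where class 0 is the background class.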
A note on smoothL1( ). Its key properties are: ① when the difference between the predicted box and the target box is large, the gradient does not blow up; ② when the difference is small, the gradient is itself small, which helps convergence. Looking at the function, it is quadratic when the absolute error is less than 1 and linear when it is greater than 1.
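For reference, the Smooth L1 function (introduced in Fast R-CNN) is:

\mathrm{smooth}_{L1}(x) =
\begin{cases}
0.5\,x^{2} & \text{if } |x| < 1 \\
|x| - 0.5 & \text{otherwise}
\end{cases}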
from torch import nn
smooth_l1 = nn.L1Loss()  # the reference code simply calls plain L1; PyTorch also provides nn.SmoothL1Loss() for the smooth variant
Below is the loss-computation part of the source code. I have taken the whole class apart here; reading the complete source on GitHub afterwards makes it easier to follow. The part of the confidence loss that involves hard negative mining is covered in the next section.
# LOCALIZATION LOSS
# Localization loss is computed only over positive (non-background) priors
loc_loss = self.smooth_l1(predicted_locs[positive_priors], true_locs[positive_priors]) # (), scalar
# Note: indexing with a torch.uint8 (byte) tensor flattens the tensor when indexing is across multiple dimensions (N & 441)
# So, if predicted_locs has the shape (N, 441, 4), predicted_locs[positive_priors] will have (total positives, 4)
# CONFIDENCE LOSS
# Confidence loss is computed over positive priors and the most difficult (hardest) negative priors in each image
# That is, FOR EACH IMAGE,
# we will take the hardest (neg_pos_ratio * n_positives) negative priors, i.e. where there is maximum loss
# This is called Hard Negative Mining - it concentrates on hardest negatives in each image, and also minimizes pos/neg imbalance
# Number of positive and hard-negative priors per image
n_positives = positive_priors.sum(dim=1) # (N)
n_hard_negatives = self.neg_pos_ratio * n_positives # (N)
# First, find the loss for all priors
conf_loss_all = self.cross_entropy(predicted_scores.view(-1, n_classes), true_classes.view(-1)) # (N * 441)
conf_loss_all = conf_loss_all.view(batch_size, n_priors) # (N, 441)
# We already know which priors are positive
conf_loss_pos = conf_loss_all[positive_priors] # (sum(n_positives))
# Next, find which priors are hard-negative
# To do this, sort ONLY negative priors in each image in order of decreasing loss and take top n_hard_negatives
conf_loss_neg = conf_loss_all.clone() # (N, 441)
conf_loss_neg[positive_priors] = 0. # (N, 441), positive priors are ignored (never in top n_hard_negatives)
Hard Negative Mining
In general the number of negative prior bboxes vastly exceeds the number of positive prior bboxes; training on all of them directly makes the network pay too much attention to negative samples, and prediction quality suffers. To keep positive and negative samples reasonably balanced, we use SSD's online hard negative mining strategy: sort the prior bboxes that are negative samples by their confidence loss and keep only those with the highest confidence loss for training, so that the positive:negative ratio is held at 1:3.
conf_loss_neg, _ = conf_loss_neg.sort(dim=1, descending=True) # (N, 441), sorted by decreasing hardness
hardness_ranks = torch.LongTensor(range(n_priors)).unsqueeze(0).expand_as(conf_loss_neg).to(device) # (N, 441)
hard_negatives = hardness_ranks < n_hard_negatives.unsqueeze(1) # (N, 441)
conf_loss_hard_neg = conf_loss_neg[hard_negatives] # (sum(n_hard_negatives))
# As in the paper, averaged over positive priors only, although computed over both positive and hard-negative priors
conf_loss = (conf_loss_hard_neg.sum() + conf_loss_pos.sum()) / n_positives.sum().float() # (), scalar
# return TOTAL LOSS
return conf_loss + self.alpha * loc_loss
Overall structure of the source code
class MultiBoxLoss(nn.Module):
    def __init__(self, priors_cxcy, threshold=0.5, neg_pos_ratio=3, alpha=1.):
        super(MultiBoxLoss, self).__init__()
        # initialize the priors, matching threshold, negative:positive ratio and loss weight alpha
        pass

    def forward(self, predicted_locs, predicted_scores, boxes, labels):
        # matching + hard negative mining + loss computation all happen in here
        return conf_loss + self.alpha * loc_loss
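A minimal sketch of how this class would be used in a training loop. The names model, train_loader and optimizer are placeholders of my own, not from the tutorial; the call follows the forward() signature above:

criterion = MultiBoxLoss(priors_cxcy=model.priors_cxcy, threshold=0.5, neg_pos_ratio=3, alpha=1.)

for images, boxes, labels in train_loader:             # boxes, labels: lists of per-image tensors
    predicted_locs, predicted_scores = model(images)   # (N, 441, 4), (N, 441, n_classes)
    loss = criterion(predicted_locs, predicted_scores, boxes, labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()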
Links:
The SSD paper.
動手學CV-Pytorch, Chapter 3 (Object Detection), Section 3.5.