NLP(四十一):解決樣本不均衡FocalLoss與GHM
Focal Loss for Dense Object Detection 是ICCV2017的Best student paper,文章思路很簡單但非常具有開拓性意義,效果也非常令人稱讚。
GHM(gradient harmonizing mechanism) 發表於 “Gradient Harmonized Single-stage Detector",AAAI2019,是基於Focal loss的改進,也是個人推薦的一篇深度學習必讀文章。
第一部分 Focal Loss
Focal Loss的引入主要是為了解決難易樣本數量不平衡(注意,有區別於正負樣本數量不平衡)的問題,實際可以使用的範圍非常廣泛,為了方便解釋,還是拿目標檢測的應用場景來說明:
單階段的目標檢測器通常會產生高達100k的候選目標,只有極少數是正樣本,正負樣本數量非常不平衡。我們在計算分類的時候常用的損失——交叉熵的公式如下:
(1)
為了解決正負樣本不平衡的問題,我們通常會在交叉熵損失的前面加上一個引數,即:
(2)
但這並不能解決全部問題。根據正、負、難、易,樣本一共可以分為以下四類:
儘管平衡了正負樣本,但對難易樣本的不平衡沒有任何幫助。而實際上,目標檢測中大量的候選目標都是像下圖一樣的易分樣本。
這些樣本的損失很低,但是由於數量極不平衡,易分樣本的數量相對來講太多,最終主導了總的損失。而本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應該主要關注與那些難分樣本(這個假設是有問題的,是GHM的主要改進物件)
這時候,Focal Loss就上場了!
一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!
(3)
舉個例,取2時,如果,,損失衰減了1000倍!
Focal Loss的最終形式結合了上面的公式(2). 這很好理解,公式(3)解決了難易樣本的不平衡,公式(2)解決了正負樣本的不平衡,將公式(2)與(3)結合使用,同時解決正負難易2個問題!
最終的Focal Loss形式如下:
實驗表明取2,取0.25的時候效果最佳。
這樣以來,訓練過程關注物件的排序為正難>負難>正易>負易。
這就是Focal Loss,簡單明瞭但特別有用。
Focal Loss的實現:
def py_sigmoid_focal_loss(pred,
target,
weight=None,
gamma=2.0,
alpha=0.25,
reduction='mean',
avg_factor=None):
pred_sigmoid = pred.sigmoid()
target = target.type_as(pred)
pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
focal_weight = (alpha * target + (1 - alpha) *
(1 - target)) * pt.pow(gamma)
loss = F.binary_cross_entropy_with_logits(
pred, target, reduction='none') * focal_weight
loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
return loss
這個程式碼很容易理解,
先定義一個pt:
然後計算:
focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)
也就是這個公式:
再把BCE損失*focal_weight就行了
程式碼來自於mmdetection\mmdet\models\losses,這個python版的sigmoid_focal_loss實現就是讓你拿去學習的,真正使用的是cuda程式設計版。真是個人性化的好框架
第二部分 GHM
那麼,Focal Loss存在什麼問題呢?
首先,讓模型過多關注那些特別難分的樣本肯定是存在問題的,樣本中有離群點(outliers),可能模型已經收斂了但是這些離群點還是會被判斷錯誤,讓模型去關注這樣的樣本,怎麼可能是最好的呢?
其次,與的取值全憑實驗得出,且和要聯合起來一起實驗才行(也就是說,和的取值會相互影響)。
GHM(gradient harmonizing mechanism) 解決了上述兩個問題。
Focal Loss是從置信度p的角度入手衰減loss,而GHM是一定範圍置信度p的樣本數量的角度衰減loss。
文章先定義了一個梯度模長g:
程式碼如下:
g = torch.abs(pred.sigmoid().detach() - target)
其中是模型預測的概率,是 ground-truth的標籤,的取值為0或1.
g正比於檢測的難易程度,g越大則檢測難度越大。
至於為什麼叫梯度模長,因為g是從交叉熵損失求梯度得來的:
假定是樣本的輸出,我們知道,
那麼,可以求出
看下圖梯度模長與樣本數量的關係:
可以看到,梯度模長接近於0的樣本數量最多,隨著梯度模長的增長,樣本數量迅速減少,但是在梯度模長接近於1時,樣本數量也挺多。
GHM的想法是,我們確實不應該過多關注易分樣本,但是特別難分的樣本(outliers,離群點)也不該關注啊!
這些離群點的梯度模長d要比一般的樣本大很多,如果模型被迫去關注這些樣本,反而有可能降低模型的準確度!況且,這些樣本的數量也很多!
那怎麼同時衰減易分樣本和特別難分的樣本呢?太簡單了,誰的數量多衰減誰唄!那怎麼衰減數量多的呢?簡單啊,定義一個變數,讓這個變數能衡量出一定梯度範圍內的樣本數量——這不就是物理上密度的概念嗎?
於是,作者定義了梯度密度——本文最重要的公式:
表明了樣本1~N中,梯度模長分佈在範圍內的樣本個數,代表了區間的長度。
因此梯度密度的物理含義是:單位梯度模長g部分的樣本個數。
接下來就簡單了,對於每個樣本,把交叉熵CE×該樣本梯度密度的倒數即可!
用於分類的GHM損失, N是總的樣本數量。
梯度密度的詳細計算過程如下:
首先,把梯度模長範圍劃分成10個區域,這裡要求輸入必須經過sigmoid計算,這樣梯度模長的範圍就限制在0~1之間:
class GHMC(nn.Module):
def __init__(self, bins=10, ......):
self.bins = bins
edges = torch.arange(bins + 1).float() / bins
......
>>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000,
0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])
edges是每個區域的邊界,有了邊界就很容易計算出梯度模長落入哪個區間內。
然後根據網路輸出pred和ground true計算loss:
注意,不管是Focal Loss還是GHM其實都是對不同樣本賦予不同的權重,所以該程式碼前面計算的都是樣本權重,最後計算GHM Loss就是呼叫了Pytorch自帶的binary_cross_entropy_with_logits,將樣本權重填進去。
# 計算梯度模長
g = torch.abs(pred.sigmoid().detach() - target)
# n 用來統計有效的區間數。
# 假如某個區間沒有落入任何梯度模長,密度為0,需要額外考慮,不然取個倒數就無窮了。
n = 0 # n valid bins
# 通過迴圈計算落入10個bins的梯度模長數量
for i in range(self.bins):
inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
num_in_bin = inds.sum().item()
if num_in_bin > 0:
# 重點,所謂的梯度密度就是1/num_in_bin
weights[inds] = num_labels / num_in_bin
n += 1
if n > 0:
weights = weights / n
# 把上面計算的weights填到binary_cross_entropy_with_logits裡就行了
loss = torch.nn.functional.binary_cross_entropy_with_logits(
pred, target, weights, reduction='sum') / num_labels
看看抑制的效果吧,也就是文章開頭的這張圖片:
同樣,對於迴歸損失:
,其中為修正的smooth L1 loss.
End~
因為本文著重論文的理解,很多細節沒有寫出,大家還是要去看一下原文的。
如果文中有錯誤還請批評指出!