1. 程式人生 > 其它 >NLP(四十一):解決樣本不均衡FocalLoss與GHM

NLP(四十一):解決樣本不均衡FocalLoss與GHM

Focal Loss for Dense Object Detection 是ICCV2017的Best student paper,文章思路很簡單但非常具有開拓性意義,效果也非常令人稱讚。

GHM(gradient harmonizing mechanism) 發表於 “Gradient Harmonized Single-stage Detector",AAAI2019,是基於Focal loss的改進,也是個人推薦的一篇深度學習必讀文章。

第一部分 Focal Loss

Focal Loss的引入主要是為了解決難易樣本數量不平衡(注意,有區別於正負樣本數量不平衡)的問題,實際可以使用的範圍非常廣泛,為了方便解釋,還是拿目標檢測的應用場景來說明:

單階段的目標檢測器通常會產生高達100k的候選目標,只有極少數是正樣本,正負樣本數量非常不平衡。我們在計算分類的時候常用的損失——交叉熵的公式如下:

(1)

為了解決正負樣本不平衡的問題,我們通常會在交叉熵損失的前面加上一個引數,即:

(2)

但這並不能解決全部問題。根據正、負、難、易,樣本一共可以分為以下四類:

儘管平衡了正負樣本,但對難易樣本的不平衡沒有任何幫助。而實際上,目標檢測中大量的候選目標都是像下圖一樣的易分樣本。

這些樣本的損失很低,但是由於數量極不平衡,易分樣本的數量相對來講太多,最終主導了總的損失。而本文的作者認為,易分樣本(即,置信度高的樣本)對模型的提升效果非常小,模型應該主要關注與那些難分樣本(這個假設是有問題的,是GHM的主要改進物件)

這時候,Focal Loss就上場了!

一個簡單的思想:把高置信度(p)樣本的損失再降低一些不就好了嗎!

(3)

舉個例,取2時,如果,,損失衰減了1000倍!

Focal Loss的最終形式結合了上面的公式(2). 這很好理解,公式(3)解決了難易樣本的不平衡,公式(2)解決了正負樣本的不平衡,將公式(2)與(3)結合使用,同時解決正負難易2個問題!

最終的Focal Loss形式如下:

實驗表明取2,取0.25的時候效果最佳。

這樣以來,訓練過程關注物件的排序為正難>負難>正易>負易。

這就是Focal Loss,簡單明瞭但特別有用。

Focal Loss的實現:

def py_sigmoid_focal_loss(pred,
                          target,
                          weight=None,
                          gamma=2.0,
                          alpha=0.25,
                          reduction='mean',
                          avg_factor=None):
    pred_sigmoid = pred.sigmoid()
    target = target.type_as(pred)
    pt = (1 - pred_sigmoid) * target + pred_sigmoid * (1 - target)
    focal_weight = (alpha * target + (1 - alpha) *
                    (1 - target)) * pt.pow(gamma)
    loss = F.binary_cross_entropy_with_logits(
        pred, target, reduction='none') * focal_weight
    loss = weight_reduce_loss(loss, weight, reduction, avg_factor)
    return loss

這個程式碼很容易理解,

先定義一個pt:

然後計算:

focal_weight = (alpha * target + (1 - alpha) *(1 - target)) * pt.pow(gamma)

也就是這個公式:

再把BCE損失*focal_weight就行了

程式碼來自於mmdetection\mmdet\models\losses,這個python版的sigmoid_focal_loss實現就是讓你拿去學習的,真正使用的是cuda程式設計版。真是個人性化的好框架

第二部分 GHM

那麼,Focal Loss存在什麼問題呢?

首先,讓模型過多關注那些特別難分的樣本肯定是存在問題的,樣本中有離群點(outliers),可能模型已經收斂了但是這些離群點還是會被判斷錯誤,讓模型去關注這樣的樣本,怎麼可能是最好的呢?

其次,與的取值全憑實驗得出,且和要聯合起來一起實驗才行(也就是說,和的取值會相互影響)。

GHM(gradient harmonizing mechanism) 解決了上述兩個問題。

Focal Loss是從置信度p的角度入手衰減loss,而GHM是一定範圍置信度p的樣本數量的角度衰減loss。

文章先定義了一個梯度模長g:

程式碼如下:

g = torch.abs(pred.sigmoid().detach() - target)

其中是模型預測的概率,是 ground-truth的標籤,的取值為0或1.

g正比於檢測的難易程度,g越大則檢測難度越大。

至於為什麼叫梯度模長,因為g是從交叉熵損失求梯度得來的:

假定是樣本的輸出,我們知道,

那麼,可以求出

看下圖梯度模長與樣本數量的關係:

可以看到,梯度模長接近於0的樣本數量最多,隨著梯度模長的增長,樣本數量迅速減少,但是在梯度模長接近於1時,樣本數量也挺多。

GHM的想法是,我們確實不應該過多關注易分樣本,但是特別難分的樣本(outliers,離群點)也不該關注啊!

這些離群點的梯度模長d要比一般的樣本大很多,如果模型被迫去關注這些樣本,反而有可能降低模型的準確度!況且,這些樣本的數量也很多!

那怎麼同時衰減易分樣本和特別難分的樣本呢?太簡單了,誰的數量多衰減誰唄!那怎麼衰減數量多的呢?簡單啊,定義一個變數,讓這個變數能衡量出一定梯度範圍內的樣本數量——這不就是物理上密度的概念嗎?

於是,作者定義了梯度密度——本文最重要的公式:

表明了樣本1~N中,梯度模長分佈在範圍內的樣本個數,代表了區間的長度。

因此梯度密度的物理含義是:單位梯度模長g部分的樣本個數。

接下來就簡單了,對於每個樣本,把交叉熵CE×該樣本梯度密度的倒數即可!

用於分類的GHM損失, N是總的樣本數量。

梯度密度的詳細計算過程如下:

首先,把梯度模長範圍劃分成10個區域,這裡要求輸入必須經過sigmoid計算,這樣梯度模長的範圍就限制在0~1之間:

class GHMC(nn.Module):
    def __init__(self, bins=10, ......):
        self.bins = bins
        edges = torch.arange(bins + 1).float() / bins
......

>>> edges = tensor([0.0000, 0.1000, 0.2000, 0.3000, 0.4000, 
                  0.5000, 0.6000, 0.7000, 0.8000,0.9000, 1.0000])

edges是每個區域的邊界,有了邊界就很容易計算出梯度模長落入哪個區間內。

然後根據網路輸出pred和ground true計算loss:

注意,不管是Focal Loss還是GHM其實都是對不同樣本賦予不同的權重,所以該程式碼前面計算的都是樣本權重,最後計算GHM Loss就是呼叫了Pytorch自帶的binary_cross_entropy_with_logits,將樣本權重填進去。

# 計算梯度模長
g = torch.abs(pred.sigmoid().detach() - target)
# n 用來統計有效的區間數。
# 假如某個區間沒有落入任何梯度模長,密度為0,需要額外考慮,不然取個倒數就無窮了。
n = 0  # n valid bins
# 通過迴圈計算落入10個bins的梯度模長數量
for i in range(self.bins):
    inds = (g >= edges[i]) & (g < edges[i + 1]) & valid
    num_in_bin = inds.sum().item()
    if num_in_bin > 0:
        # 重點,所謂的梯度密度就是1/num_in_bin
        weights[inds] = num_labels / num_in_bin 
        n += 1
if n > 0:
    weights = weights / n
# 把上面計算的weights填到binary_cross_entropy_with_logits裡就行了
loss = torch.nn.functional.binary_cross_entropy_with_logits(
    pred, target, weights, reduction='sum') / num_labels

看看抑制的效果吧,也就是文章開頭的這張圖片:

同樣,對於迴歸損失:

,其中為修正的smooth L1 loss.

End~

因為本文著重論文的理解,很多細節沒有寫出,大家還是要去看一下原文的。

如果文中有錯誤還請批評指出!