1. 程式人生 > >焦點損失函式 Focal Loss 與 GHM

焦點損失函式 Focal Loss 與 GHM

文章來自公眾號【機器學習煉丹術】 ## 1 focal loss的概述 焦點損失函式 Focal Loss(2017年何凱明大佬的論文)被提出用於密集物體檢測任務。 當然,在目標檢測中,可能待檢測物體有1000個類別,然而你想要識別出來的物體,只是其中的某一個類別,這樣其實就是一個樣本非常不均衡的一個分類問題。 而Focal Loss簡單的說,就是解決樣本數量極度不平衡的問題的。 說到樣本不平衡的解決方案,相比大家是知道一個混淆矩陣的f1-score的,但是這個好像不能用在訓練中當成損失。而Focal loss可以在訓練中,**讓小數量的目標類別增加權重,讓分類錯誤的樣本增加權重**。 先來看一下簡單的二值交叉熵的損失: ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_2cd06c171db2ee3778f627ced22607a7.jpg) - y’是模型給出的預測類別概率,y是真實樣本。就是說,如果一個樣本的真實類別是1,預測概率是0.9,那麼$-log(0.9)$就是這個損失。 - 講道理,一般我不喜歡用二值交叉熵做例子,用多分類交叉熵做例子會更舒服。 **** **【然後看focal loss的改進】:** ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_75ba936c5e559243f5868716a859d1a3.jpg) 這個增加了一個$(1-y')^\gamma$的權重值,怎麼理解呢?就是如果給出的正確類別的概率越大,那麼$(1-y')^\gamma$就會越小,說明**分類正確的樣本的損失權重小**,反之,**分類錯誤的樣本的損權重大**。 **** **【focal loss的進一步改進】:** ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_a6c9f37691888d4814486417ea9b2561.jpg) 這裡增加了一個$\alpha$,這個alpha在論文中給出的是0.25,這個就是**單純的降低正樣本或者負樣本的權重,來解決樣本不均衡的問題**。 兩者結合起來,就是一個可以解決樣本不平衡問題的損失focal loss。 **** 【總結】: 1. $\alpha$解決了樣本的不平衡問題; 2. $\beta$解決了難易樣本不平衡的問題。讓樣本更重視難樣本,忽視易樣本。 3. 總之,Focal loss會的關注順序為:樣本少的、難分類的;樣本多的、難分類的;樣本少的,易分類的;樣本多的,易分類的。 ## 2 GHM - GHM是Gradient Harmonizing Mechanism。 這個GHM是為了解決Focal loss存在的一些問題。 **【Focal Loss的弊端1】** 讓模型過多的關注特別難分類的樣本是會有問題的。樣本中有一些異常點、離群點(outliers)。所以模型為了擬合這些非常難擬合的離群點,就會存在過擬合的風險。 ### 2.1 GHM的辦法 Focal Loss是從置信度p的角度入手衰減loss的。而GHM是一定範圍內建信度p的樣本數量來衰減loss的。 首先定義了一個變數**g**,叫做**梯度模長(gradient norm)**: ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_a3dbcb7519660eab730c4cce291a7c99.jpg) 可以看出這個梯度模長,其實就是模型給出的置信度$p^*$與這個樣本真實的標籤之間的差值(距離)。**g越小,說明預測越準,說明樣本越容易分類。** 下圖中展示了g與樣本數量的關係: ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_4efea42fe0389c6d862d25d5b23c6084.jpg) **【從圖中可以看到】** - 梯度模長接近於0的樣本多,也就是易分類樣本是非常多的 - 然後樣本數量隨著梯度模長的增加迅速減少 - 然後當梯度模長接近1的時候,樣本的數量又開始增加。 GHM是這樣想的,對於梯度模長小的易分類樣本,我們忽視他們;但是focal loss過於關注難分類樣本了。**關鍵是難分類樣本其實也有很多!**,如果模型一直學習難分類樣本,那麼可能模型的精確度就會下降。所以GHM對於難分類樣本也有一個衰減。 那麼,GHM對易分類樣本和難分類樣本都衰減,那麼真正被關注的樣本,就是那些不難不易的樣本。而抑制的程度,可以根據樣本的數量來決定。 這裡定義一個**GD,梯度密度**: $$GD(g)=\frac{1}{l(g)}\sum_{k=1}^N{\delta(g_k,g)}$$ - $GD(g)$是計算在梯度g位置的梯度密度; - $\delta(g_k,g)$就是樣本k的梯度$g_k$是否在$[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$這個區間內。 - $l(g)$就是$[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$這個區間的長度,也就是$\epsilon$ **總之,$GD(g)$就是梯度模長在$[g-\frac{\epsilon}{2},g+\frac{\epsilon}{2}]$內的樣本總數除以$\epsilon$.** 然後把每一個樣本的交叉熵損失除以他們對應的梯度密度就行了。 $$L_{GHM}=\sum^N_{i=1}{\frac{CE(p_i,p_i^*)}{GD(g_i)}}$$ - $CE(p_i,p_i^*)$表示第i個樣本的交叉熵損失; - $GD(g_i)$表示第i個樣本的梯度密度; ### 2.2 論文中的GHM 論文中呢,是把梯度模長劃分成了10個區域,因為置信度p是從0\~1的,所以梯度密度的區域長度就是0.1,比如是0\~0.1為一個區域。 下圖是論文中給出的對比圖: ![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_f3d4ad69cd9d14f5dc3c40c6d2089288.jpg) **【從圖中可以得到】** - 綠色的表示交叉熵損失; - 藍色的是focal loss的損失,發現梯度模長小的損失衰減很有效; - 紅色是GHM的交叉熵損失,發現梯度模長在0附近和1附近存在明顯的衰減。 當然可以想到的是,GHM看起來是需要整個樣本的模型估計值,才能計算出梯度密度,才能進行更新。也就是說mini-batch看起來似乎不能用GHM。 在GHM原文中也提到了這個問題,如果光使用mini-batch的話,那麼很可能出現不均衡的情況。 【我個人覺得的處理方法】 1. 可以使用上一個epoch的梯度密度,來作為這一個epoch來使用; 2. 或者一開始先使用mini-batch計算梯度密度,然後模型收斂速度下降之後,再使用第一種方式進行更新。 ## 3 python實現 上面講述的關鍵在於focal loss實現的功能: 1. **分類正確的樣本的損失權重小,分類錯誤的樣本的損權重大**。 2. **樣本過多的類別的權重較小** 在CenterNet中預測中心點位置的時候,也是使用了Focal Loss,但是稍有改動。 ### 3.1 概述 [![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_4bf0036197b9b9a6210cc724bb6e129b.jpg)](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_4bf0036197b9b9a6210cc724bb6e129b.jpg) 這裡面和上面講的比較類似,我們忽視腳標。 - 假設$Y=1$,那麼預測的$\hat{Y}$越靠近1,說明預測的約正確,然後$(1-\hat{Y})^\alpha$就會越小,從而體現**分類正確的樣本的損失權重小**;otherwize的情況也是這樣。 - 但是這裡的otherwize中多了一個$(1-Y)^\beta$,這個是用來平衡樣本不均衡問題的,在後面的程式碼部分會提到CenterNet的熱力圖。就會明白這個了。 ### 3.2 程式碼講解 下面通過程式碼來理解: ```python class FocalLoss(nn.Module): def __init__(self): super().__init__() self.neg_loss = _neg_loss def forward(self, output, target, mask): output = torch.sigmoid(output) loss = self.neg_loss(output, target, mask) return loss ``` 這裡面的output可以理解為是一個1通道的特徵圖,每一個pixel的值都是模型給出的置信度,然後通過sigmoid函式轉換成0~1區間的置信度。 而target是CenterNet的熱力圖,這一點可能比較難理解。打個比方,一個10\*10的全都是0的特徵圖,然後這個特徵圖中只有一個pixel是1,那麼這個pixel的位置就是一個目標檢測物體的中心點。有幾個1就說明這個圖中有幾個要檢測的目標物體。 然後,如果一個特徵圖上,全都是0,只有幾個孤零零的1,未免顯得過於稀疏了,直觀上也非常的不平滑。所以CenterNet的熱力圖還需要對這些1為中心做一個高斯 [![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_0615e8fab0b921d58bf1c78558d2bdc2.jpg)](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_0615e8fab0b921d58bf1c78558d2bdc2.jpg) 可以看作是一種平滑: [![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_3578060eb6a5cdc81708d40478dbfe93.jpg)](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_3578060eb6a5cdc81708d40478dbfe93.jpg) 可以看到,數字1的四周是同樣的數字。這是一個以1為中心的高斯平滑。 **** 這裡我們回到上面說到的$(1-Y)^\beta$: [![](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_4bf0036197b9b9a6210cc724bb6e129b.jpg)](http://helloworld2020.net/wp-content/uploads/2020/06/wp_editor_md_4bf0036197b9b9a6210cc724bb6e129b.jpg) 對於數字1來說,我們計算loss自然是用第一行來計算,但是對於1附近的其他點來說,就要考慮$(1-Y)^\beta$了。越靠近1的點的$Y$越大,那麼$(1-Y)^\beta$就會越小,這樣從而降低1附近的權重值。其實這裡我也講不太明白,就是根據距離1的距離降低負樣本的權重值,從而可以實現**樣本過多的類別的權重較小**。 **** 我們回到主題,對output進行sigmoid之後,與output一起放到了neg_loss中。我們來看什麼是neg_loss: ```python def _neg_loss(pred, gt, mask): pos_inds = gt.eq(1).float() * mask neg_inds = gt.lt(1).float() * mask neg_weights = torch.pow(1 - gt, 4) loss = 0 pos_loss = torch.log(pred) * torch.pow(1 - pred, 2) * pos_inds neg_loss = torch.log(1 - pred) * torch.pow(pred, 2) * \ neg_weights * neg_inds num_pos = pos_inds.float().sum() pos_loss = pos_loss.sum() neg_loss = neg_loss.sum() if num_pos == 0: loss = loss - neg_loss else: loss = loss - (pos_loss + neg_loss) / num_pos return loss ``` 先說一下,這裡面的mask是根據特定任務中加上的一個小功能,就是在該任務中,一張圖片中有一部分是不需要計算loss的,所以先用過mask把那個部分過濾掉。這裡直接忽視mask就好了。 從```neg_weights = torch.pow(1 - gt, 4)```可以得知$\beta=4$,從下面的程式碼中也不難推出,$\alpha=2$,剩下的內容就都一樣了。 把每一個pixel的損失都加起來,除以目標物體的數量