Focal loss論文解析
阿新 • • 發佈:2020-10-06
Focal loss是目標檢測領域的一篇十分經典的論文,它通過改造損失函式提升了一階段目標檢測的效能,背後關於類別不平衡的學習的思想值得我們深入地去探索和學習。正負樣本失衡不僅僅在目標檢測演算法中會出現,在別的機器學習任務中同樣會出現,這篇論文為我們解決類似問題提供了一個很好的啟發,所以我認為無論是否從事目標檢測領域相關工作,都可以來看一看這篇好論文。
論文的關鍵性改進在於對**損失函式的改造**以及對**引數初始化**的設定。
首先是對損失函式的改造。論文中指出,限制目標檢測網路效能的一個關鍵因素是類別不平衡。二階段目標檢測演算法相比於一階段目標檢測演算法的優點在於,二階段的目標檢測演算法通過候選框篩選演算法(proposal stage)過濾了大部分背景樣本(負樣本),使得正負樣本比例適中;而一階段的目標檢測演算法中,需要處理大量的負樣本,使得包含目標的正樣本資訊被淹沒。這使得一階段目標檢測演算法的識別準確度上比不上二階段的目標檢測演算法。
**為了解決這個問題,Focal loss使用了動態加權的思想,對於置信度高的樣本,損失函式進行降權;對於置信度低的樣本,損失函式進行加權,使得網路在反向傳播時,置信度低的樣本能夠提供更大的梯度佔比,即從未學習好的樣本中獲取更多的資訊(就像高中時期的錯題本一樣,對於容易錯的題目,包含了更多的資訊量,需要更加關注這種題目;而對於屢屢正確的題目,可以少點關注,說明已經掌握了這型別的題目)**。
其巧妙之處就在於,通過了網路本身輸出的概率值(置信度)去構建權重,實現了自適應調整權重的目的。
## 公式的講解
Focal loss是基於交叉熵損失構建的,二元交叉熵的公式為
$$
\mathrm{CE}(p, y)=\left\{\begin{array}{ll}
-\log (p) & \text { if } y = +1 \\
-\log (1-p) & \text { y = -1 }
\end{array}\right.
$$
為了方便表示,定義$p_t$為分類正確的概率
$$
p_{t}=\left\{\begin{array}{ll}
p & \text { if } y = +1 \\
1-p & \text { y = -1 }
\end{array}\right.
$$
則交叉熵損失表示為$CE(p,y)=CE(p_t)=-log(p_t)$。如前文所述,通過置信度對損失進行縮放得到Focal loss。
$$
FL(p_t)=-\alpha_t(1-p_t)^\gamma log(p_t)= \alpha_t(1-p_t)^\gamma\times CE(p_t)
$$
其中,$\alpha_{1}=\left\{\begin{array}{ll}
\alpha & \text { if } y = +1 \\
1-\alpha & \text { y = -1 }
\end{array}\right.$為縮放乘數(直接調整正負樣本的權重),$\gamma$為縮放因子,$(1-p_t)$可以理解為分類錯誤的概率。公式中起到關鍵作用的部分是$(1-p_t)^\gamma$。為了給易分樣本降權,通常設定$\gamma>1$。
對於正確分類的樣本,$p_t \to 1 \Rightarrow(1-p_t) \to 0$,受到$\gamma$的影響很大,$(1-p_t)^\gamma \approx 0$;
對於錯誤分類的樣本,$p_t \to 0 \Rightarrow(1-p_t) \to 1$,受到$\gamma$的影響較小,$(1-p_t)^\gamma \approx (1-p_t)$,對於難分樣本的降權較小。
Focal loss本質上是通過置信度給易分樣本進行更多的降權,對難分樣本進行更少的降權,實現對難分樣本的關注。
## 引數初始化
論文中還有一個比較重要的點是對於子網路最後一層權重的初始化方式,關係到網路初期訓練的效能。這裡結合論文和我看過的一篇博文進行詳細的展開。常規的深度學習網路初始化演算法,使用的分佈是高斯分佈,根據概率論知識,兩個高斯分佈的變數的乘積仍然服從高斯分佈。假設權重$w\sim N(\mu_w,\sigma_w^2)$,最後一層的特徵$x\sim N(\mu_x,\sigma_x^2)$,則$wx \sim N(\mu_{wx},\sigma_{wx}^2)$。
$$
\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}\\
\sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2}
$$
其中$x$的分佈取決於網路的結果,$w$的分佈引數為$\mu_w=0,\sigma_w^2=10^{-4}$,只需$x$的分佈引數滿足$\sigma_x^2\gg 10^{-4},\sigma_x^2\gg10^{-4}\mu_x$成立,有如下的不等式。(一般情況下,這兩個條件是成立的。)
$$
\mu_{wx}=\frac{\mu_w \sigma_x^2+\mu_x \sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}\mu_x}{\sigma_x^2+10^{-4}}\ll\frac{10^{-4}\mu_x}{10^{-4}\mu_x+10^{-4}}=\frac{1}{1+\frac{1}{\mu_x}}\approx0 \text{由於}\mu_x\text{一般為分數(網路的輸入經過歸一化到0至1,隨著網路加深的連乘,分數會越來越小)}\\
\sigma_{wx}=\frac{\sigma_x^2\sigma_w^2}{\sigma_x^2+\sigma_w^2}=\frac{10^{-4}}{1+\frac{10^{-4}}{\sigma_x^2}}\approx10^{-4} \text{由於}\sigma_x^2\gg10^{-4}
$$
根據上述推導,$wx$服從一個均值為0,方差很小的高斯分佈,可以在很大概率上認為它就等於0,所以網路最後一層的輸出為
$$
p=sigmoid(wx+b)=sigmoid(b)=\frac{1}{1+e^{-b}}=\pi
$$
令$\pi$為網路初始化時輸出為正類的概率,設定為一個很小的值(0.01),則網路在訓練初期,將樣本都劃分為負類,對於正類$p_t=0.01$,負類$p_t=0.99$,則訓練初期,正類都被大概率錯分,負類都被大概率正確分類,所以在訓練初期更加關注正類,避免初期的正類資訊被淹沒在負類資訊中。
## 總結
總的來說,Focal loss通過對損失函式的簡單改進,實現了一種自適應的困難樣本挖掘策略,使得網路在學習過程中關注更難學習的樣本,在一定程度上解決了正負樣本分佈不均衡的問題(由於正負樣本分佈不均衡,對於稀少的正樣本學習不足,導致正樣本普遍表現為難分樣本)。
## 參考資料
[論文原文](https://arxiv.org/abs/1708.02002)
[一篇不錯的解析部落格](https://leimao.github.io/blog/Focal-Loss-Expl