1. 程式人生 > >處理樣本不平衡LOSS—Focal Loss

處理樣本不平衡LOSS—Focal Loss

0 前言

Focal Loss是為了處理樣本不平衡問題而提出的,經時間驗證,在多種任務上,效果還是不錯的。在理解Focal Loss前,需要先深刻理一下交叉熵損失,和帶權重的交叉熵損失。然後我們從樣本權重的角度出發,理解Focal Loss是如何分配樣本權重的。Focal是動詞Focus的形容詞形式,那麼它究竟Focus在什麼地方呢?(詳細的程式碼請看Gitee)。

1 交叉熵

1.1 交叉熵損失(Cross Entropy Loss)

有\(N\)個樣本,輸入一個\(C\)分類器,得到的輸出為\(X\in \mathcal{R}^{N\times C}\),它共有\(C\)類;其中某個樣本的輸出記為\(x\in \mathcal{R}^{1\times C}\),即\(x[j]\)是\(X\)的某個行向量,那麼某個交叉熵損失可以寫為如下公式:

\[ \text{loss}\left( x,\text{class} \right) =-\log \left( \frac{\exp \left( x\left[\text{class} \right] \right)}{\sum_j{\exp\left( x\left[ j \right] \right)}} \right) =-x\left[\text{class} \right] +\log \left( \sum_j{\exp\left( x\left[ j \right] \right)} \right) \tag{1-1} \]
其中\(\text{class}\in [0,\ C)\)是這個樣本的類標籤

,如果給出了類標籤的權重向量\(W\in \mathcal{R}^{1\times C}\),那麼帶權重的交叉熵損失可以更改為如下公式:

\[ \operatorname{loss}(x, \text {class})=W[\text {class}]\left(-x[\text {class}]+\log \left(\sum_{j} \exp (x[j])\right)\right) \tag{1-2} \]

最終對這個\(N\)個樣本的損失求和或者求平均

\[ \ell = \begin{cases} \sum_{i}^{N}{\text{loss}(x^{(i)},\ \text{class}^{(i)})}&\text{, sum}\\ \dfrac{1}{N}\sum_{i}^{N}{\text{loss}(x^{(i)},\ \text{class}^{(i)})}&\text{, mean} \end{cases} \tag{1-3} \]

這個就是我們平時經常用到的交叉熵損失了。

1.2 二分類交叉熵損失(Binary Cross Entropy Loss)

上面所提到的交叉熵損失是適用於多分類(二分類及以上)的,但是它的公式看起來似乎與我們平時在書上或論文中看到的不一樣,一般我們常見的交叉熵損失公式如下:

\[ l = -y\log{\hat{y}}-(1-y)\log{(1-\hat{y})} \]

這是一個典型的二分類交叉熵損失,其中\(y\in\{0,\ 1\}\)表示標籤值,\(\hat{y}\in[0,\ 1]\)表示分類模型的類別1預測值。上面這個公式是一個綜合的公式,它等價於:

\[ l = \begin{cases} -\log{\hat{y}_0} &y=0 \\ -\log{\hat{y}_1} &y=1 \end{cases}; \quad \text{where}\quad \hat{y}_0+\hat{y}_1 = 1 \]

其中\(\hat{y}_0, \hat{y}_1\)是二分類模型輸出的2個偽概率值

例:如果二分類模型是神經網路,且最後一層為: 2個神經元+Softmax,那麼\(\hat{y}_0, \hat{y}_1\)就對應著這兩個神經元的輸出值。當然它也可以帶上類別的權重。

同樣地,有\(N\)個樣本,輸入一個2分類器,得到的輸出為\(X\in \mathcal{R}^{N\times 2}\),再經過Softmax函式,\(\hat{Y}=\sigma(X)\in \mathcal{R}^{N\times 2}\),標籤為\(Y\in \mathcal{R}^{N\times 2}\),每個樣本的二分類損失記為\(l^{(i)}, i=0,1,2,\cdots,N\),最終對這個\(N\)個樣本的損失求和或者求平均

\[ \ell = \begin{cases} \sum_{i}^{N}l^{(i)}&\text{, sum}\\ \dfrac{1}{N}\sum_{i}^{N}l^{(i)}&\text{, mean} \end{cases}; \ \ \ l^{(i)} = -y^{(i)}\log{\hat{y}^{(i)}}-(1-y^{(i)})\log{(1-\hat{y}^{(i)})} \]

注:如果一次只訓練一個樣本,即\(N=1\),那麼上面帶類別權重的損失中的權重是無效的。因為權重是相對的,某一個樣本的權重大,那麼必然需要有另一個樣本的權重小,這樣才能體現出這一批樣本中某些樣本的重要性。\(N=1\)時,已沒有權重的概念,它是唯一的,也是最重要的。\(N=1\),或者說batch_size=1這種情況在訓練視訊\文章資料時,是會常出現的。由於我們顯示/記憶體的限制,而視訊/文章資料又比較大,一次只能訓練一個樣本,此時我們就需要注意權重的問題了。

2 Focal Loss

2.1 基本思想

一般來講,Focal Loss(以下簡稱FL)[1]是為解決樣本不平衡的問題,但是更準確地講,它是為解決難分類樣本(Hard Example)易分類樣本(Easy Example)的不平衡問題。對於樣本不平衡,其實通過上面的帶權重的交叉熵損失便可以一定程度上解決這個問題,但是在實際問題中,以權重來解決樣本不平衡問題的效果不夠理想,此時我們應當思考,表面上我們的樣本不平衡,但實質上導致效果不好的原因也許並不是簡單地因為樣本不平衡,而是因為樣本中存在一些Hard Example,同時存在許多Easy Example,Easy Example雖然容易被分類器分辨,損失較小,但是由於其數量大,它們累積起來依然於大於Hard Example的Loss值,因此我們需要給Hard Example較大的權重,而Easy Example較小的權重。

那麼什麼叫Hard Example,什麼叫Easy Example呢?看下面的圖就知道了。

圖2-1 Hard Example 圖2-2 Easy Example1 圖2-3 Easy Example2 圖2-4 Example Space

假設,我們的任務是訓練一個分類器,分類出人和馬,對於上面的三張圖,圖2-2和圖2-3應該是非常容易判斷出來的,但是圖2-1就是不那麼容易了,它即有人的特徵,又有馬的特徵,非常容易混淆。這種樣本雖然在資料集中出現的頻率可能並不高,但是想要提高分類器的效能,需要著力解決這種樣本分類問題。

提出Hard Example和Easy Example後,可以將樣本空間劃分為如圖2-4所示的樣本空間。其中縱軸為多數類樣本(Majority Class)少數類樣本(Minority Class),上面的帶權重的交叉熵損失只能解決Majority Class和Minority Class的樣本不平衡問題,並沒有考慮Hard Example和Easy Example的問題,Focal Loss的提出就是為解決這個難易樣本的分類問題。

2.2 Focal Loss解決方案

要解決難易樣本的分類問題,首先就需要找出Hard Example和Easy Example。這對於神經網路來說,應該是一件比較容易的事情。如圖2-6所示,這是一個5分類的網路,神經網路的最後一層輸出時,加上一個Softmax或者Sigmoid就會得到輸出的偽概率值,代表著模型預測的每個類別的概率,

圖2-6 Easy Example Classifier Output 圖2-7 Hard Example Classifier Output

圖2-6中,樣本標籤為1,分類器輸出值最大的為第1個神經元(以0開始計數),這剛好預測準確,而且其輸出值2也比其它神經元的輸出值要大不少,因此可以認為這是一個易分類樣本(Easy Example);圖2-7的樣本標籤是3,分類器輸出值最大的為第4個神經元,並且這幾個神經元的輸出值都相差不大,神經網路無法準確判斷這個樣本的類別,所以可以認為這是一個難分類樣本(Hard Example)。其實說白了,判斷Easy/Hard Example的方法就是看分類網路的最後的輸出值。如果網路預測準確,且其概率較大,那麼這是一個Easy Example,如果網路輸出的概率較小,這是一個Hard Example。下面用數學公式嚴謹地表達來Focal Loss的表示式。

令一個\(C\)類分類器的輸出為\(\boldsymbol{y}\in \mathcal{R}^{C\times 1}\),定義函式\(f\)將輸出\(\boldsymbol{y}\)轉為偽概率值\(\boldsymbol{p}=f(\boldsymbol{y})\),當前樣本的類標籤為\(t\),記\(p_t=\boldsymbol{p}[t]\),它表示分類器預測為\(t\)類的概率值,再結合上面的交叉熵損失,定義Focal Loss為:

\[ \text{FL} = -(1-p_t)\log(p_t) \tag{2-1} \]

這實質就是交叉熵損失前加了一個權重,只不過這個權重有點不一樣的來頭。為了更好地控制前面權重的大小,可以給前面的權重係數新增一個指數\(\gamma\),那麼更改式(2-1):

\[ \text{FL} = -(1-p_t)^\gamma\log(p_t) \tag{2-2} \]

其中\(\gamma\)一值取值為2就好,\(\gamma\)取值為0時與交叉熵損失等價,\(\gamma\)越大,就越抑制Easy Example的損失,相對就會越放大Hard Example的損失。同時為解決樣本類別不平衡的問題,可以再給式(2-2)新增一個類別的權重\(\alpha_t\)(這個類別權重上面的交叉熵損失已經實現):

\[ \text{FL} = -\alpha_t(1-p_t)^\gamma\log(p_t) \tag{2-3} \]

到這裡,Focal Loss理論就結束了,非常簡單,但是有效。

3 Focal Loss實現(Pytorch)

3.1 交叉熵損失實現(numpy)

為了更好的理解Focal Loss的實現,先理解交叉熵損失的實現,我這裡用numpy簡單地實現了一下交叉熵損失。

import numpy as np

def cross_entropy(output, target):
    out_exp = np.exp(output)
    out_cls = np.array([out_exp[i, t] for i, t in enumerate(target)])
    ce = -np.log(out_cls / out_exp.sum(1))
    return ce

程式碼中第5行,可能稍微有點難以理解,它不過是為了找出標籤對應的輸出值。比如第2個樣本的標籤值為3,那它分類器的輸出應當選擇第2行,第3列的值。

3.2 Focal Loss實現

下面的程式碼的10~12行:依據輸出,計算概率,再將其轉為focal_weight;15~16行,將類權重和focal_weight新增到交叉熵損失,得到最終的focal_loss;18~21行,實現meansum兩種reduction方法,注意求平均不是簡單的直接平均,而是加權平均。

class FocalLoss(nn.Module):
    def __init__(self, gamma=2, weight=None, reduction='mean'):
        super(FocalLoss, self).__init__()
        self.gamma = gamma
        self.weight = weight
        self.reduction = reduction

    def forward(self, output, target):
        # convert output to presudo probability
        out_target = torch.stack([output[i, t] for i, t in enumerate(target)])
        probs = torch.sigmoid(out_target)
        focal_weight = torch.pow(1-probs, self.gamma)

        # add focal weight to cross entropy
        ce_loss = F.cross_entropy(output, target, weight=self.weight, reduction='none')
        focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            focal_loss = (focal_loss/focal_weight.sum()).sum()
        elif self.reduction == 'sum':
            focal_loss = focal_loss.sum()

        return focal_loss

注:上面實現中,output的維度應當滿足output.dim==2,並且其形狀為(batch_size, C),且target.max()<C

總結

Focal Loss從2017年提出至今,該論文已有2000多引用,足以說明其有效性。其實從本質上講,它也只不過是給樣本重新分配權重,它相對類別權重的分配方法,只不過是將樣本空間進行更為細緻的劃分,從圖2-4很容易理解,類別權重的方法,只是將樣本空間劃分為藍色線上下兩個部分,而加入難易樣本的劃分,又可以將空間劃分為左右兩個部分,如此,樣本空間便被劃分4個部分,這樣更加細緻。其實藉助於這個思想,我們是否可以根據不同任務的需求,更加細緻劃分我們的樣本空間,然後再相應的分配不同的權重呢?

參考文獻

[1] Lin, T.-Y., Goyal, P., Girshick, R., He, K., & Dollár, P. (2017). Focal loss for dense object detection. Paper presented at the Proceedings of the IEEE international conference on computer vision.