關於Focal Loss【轉自以學習、回憶】
是解決樣本不均衡問題的一種方法,面試常問,但是自己一知半解 遂尋文學習
來源:CSDN GHZhao_GIS_RS
連結:https://blog.csdn.net/u014311125/article/details/109470137
轉載一篇以學習、回憶。
轉載正文開始 有刪減
個人覺的要真正理解Focal Loss,有三個關鍵點需要清楚,分別對應基礎公式,超引數α,超引數γ。
一、二分類(sigmoid)和多分類(softmax)的交叉熵損失表達形式是有區別的。
二、理解什麼是難分類樣本,什麼是易分類樣本?搞清難易分類樣本是搞清楚Focal Loss中的超引數γ作用的關鍵。
三、負樣本的α值到底該是0.25還是0.75呢?這個問題對應Focal Loss中的超引數α的調參。
理解上面三點應該就能搞清楚二分類Focal Loss的基本思想,然後就可以推廣到多分類問題上。
理解關鍵點一:基礎公式
二分類和多分類的交叉熵的區別具體可以參考文章《一文搞懂交叉熵損失》(https://www.cnblogs.com/wangguchangqing/p/12068084.html#autoid-0-2-0)
1.1、二分類交叉熵
在做二分類的任務時,一般是用sigmoid作為最後的啟用函式,輸出只有一個代表樣本為正的概率值p,二分類非正即負,所以樣本為負的概率值為1-p。
則以sigmoid作為啟用函式的二分類任務交叉熵損失的計算公式為:
1.2、多分類交叉熵
在做多分類的時候,一般是以softmax作為最後的啟用函式的,輸出有多個值,對應每個分類的概率值,和為1。
則以sofmax作為啟用函式的多分類任務的交叉熵損失計算公式為
其中p c p_{c}p
c
表示softmax啟用函式輸出結果中第c類的對應的值。
注意:論文中是基於以sigmoid為啟用函式來作為二分類交叉熵損失的。我在最開始學Focal Loss的時候老是將sigmoid和softmax混著看,一會用sigmoid來套公式,一會用softmax來套公式,很容易把自己搞蒙。
文章的備註裡也指出可以很容易將Focal Loss應用於多分類,為了簡單起見,文章中關注的是二分類情況。
理解關鍵點二:
論文將交叉熵損失公式做了進一步的簡化:
其中
所以:
這裡pt的理解比較關鍵。pt的大小實際能反映出樣本難易分類的程度。舉個例子,當樣本為正樣本(y=1)時,如果模型預測的p=0.3,表示模型預測該樣本為負樣本,模型預測錯誤,
pt=0.3,如果模型預測的p=0.8,表示模型預測該樣本為正樣本,模型預測正確,
pt=0.8。當樣本為負樣本(y=0)時,如果模型預測的p=0.3,表示模型判斷該樣本為負樣本,判斷正確,=1-p=0.7。如果模型輸出的p=0.8,表示模型判斷該樣本為正樣本,模型預測錯誤,pt=1-p=0.2.對應下表
可以看到,不管是正樣本還是負樣本,模型預測時pt都很大,預測錯誤時pt很小,所以pt代表了模型對樣本預測正確的概率。
接下來看論文中一上來就給的一張圖。
橫座標是pt,可以看出作者指出pt∈(0.6,1)區間為易分類樣本。針對上邊的例子再囉嗦幾句,對於一個正樣本,如果模型得到的預測的p總是在0.5以上,則說明該樣本很容易被分類正確,所以是易分類樣本,此時pt=p,pt也總是在0.5以上,如果模型得到的預測的p總是在0.5以下,則說明該樣本很難被正確分類,所以為難分類樣本,此時pt也總是在0.5以下;同理對於一個負樣本,模型預測的p很容易在0.5以下,表明模型很容易將樣本正確分類,所以是易分類樣本,pt=1-p,pt總是在0.5以上,如果模型得到的預測的p總是在0.5以上,則說明針對這類樣本模型總是分類錯誤,所以是難分類樣本,pt=1-p,pt總是在0.5以下。
總結一下:易分類樣本的特徵pt>0.5 難分類樣本特徵:pt<0.5 pt值越大,表示預測越準確。
(自注:這裡p值咋來的?怎麼判斷的Pt 不還是需要p值? 這個p值還是訓練過程才能出來吧)
2.2 γ引數
在訓練模型的時候,我們希望模型更加關注難分類樣本,所以會考慮將難分類樣本在損失函式中的比重加大。
作者在原始的二分類交叉熵函式中增加了一項,對原始交叉熵損失做了衰減。
經過對pt的分析可知,難分類樣本的pt值小,1-pt大。易分類樣本的pt值大,1-pt值小。不管是難分類樣本還是易分類樣本,Focal Loss相對於原始的CE loss都做了衰減,只是難分類樣本相對於易分類樣本衰減的少。這裡超引數γ決定了衰減的程度。γ損失越大,損失衰減越明顯。對上面例子:
理解關鍵點三:超引數at
我們在做實際模型訓練的時候,經常會遇到各類樣本數量比例不平衡的情況,對於二分類任務,負樣本的數量遠遠多於正樣本,導致模型更多關注在負樣本上,忽略正樣本。因此在使用交叉熵損失的時候通常會增加一個平衡引數用來調節正負樣本的比重。
at的定義和pt的定義類似。應該是
可以知道at代表了正樣本的權重,1-at代表了負樣本的權重,這兩個值應該是正負樣本數量比例的反比,如正樣本佔0.2,負樣本數量佔0.8。那麼at=0.8,1-at=0.2。以詞來達到平衡正樣本的目的,這樣理解看來是沒有問題的。
所以Focal Loss的最終表示式:
作者指出加入at平衡引數比不加時精度有所提升。給出了實驗引數,at=0.25 γ=2時精度最高。這時就有一個問題了。at代表計算損失時正樣本的調節權重,而正樣本數量一半小於負樣本,所以正樣本的權重應該大於負樣本的權重, 那作者實驗中最佳的正樣本權重(α {\alpha}α=0.25)為啥比負樣本權重(1-α {\alpha}α=0.75)還要低呢?明明負樣本的數量已經遠遠大於正樣本的數量了,為啥還要增加損失函式中負樣本的比重呢?這不是矛盾嗎?
其實作者在論文裡給出瞭解釋。有兩處:
這段有兩點資訊:
1.at代表了樣本數量較少的類的權重,也就是絕大多數情況下的正樣本。
2.at與γ是相互作用的,γ增加 a應該稍微降低。
這段話有三點意思:
1.低a對應高γ
2.負樣本分類,權重已經被降低很多了,所以無需給正樣本再增加權重。
3.Focal Loss種γ佔主要地位。
多分類
接下來吧Focal Loss推廣到多分類任務中。
接下來我們把Focal Loss推廣到多分類任務中,看看多分類中的Focal Loss公式應該是怎樣的。
首先我們再來回顧一下二分類的Focal Loss的推導過程,其實就是在CE Loss的基礎上增加了兩項因子
,其中為用來調整難易分類樣本的比重,at對經過係數衰減後的損失進行調整。那麼
這裡就有一個問題了:at和在多分類任務中該怎麼表示?
最初的二分類交叉熵損失為:
對應的Focal Loss為
將at pt展開
還是二分類,我們把啟用函式換成softmax,我們知道sotmax輸出的是每個類的概率值,和為1。使用softmax時樣本的標籤為onehot形式(y1,y2),二分類情況下,第1類標籤為(1,0),第二類標籤為(0,1).假設softmax輸出為(p1,p2)分別對應1,2類的概率。
則以softmax為啟用的二分類的交叉熵損失為
先加入衰減係數γ
再加入a
因為標籤是onehot形式,某類樣本的標籤中的只只有再對應位置上為1 其餘為0,所以上式寫成
其中ac代表第c類樣本的權重,pc代表softmax輸出的第c類的概率值
後有程式碼