1. 程式人生 > 其它 >ICLR2020 | 解決長尾分佈的解耦學習方法

ICLR2020 | 解決長尾分佈的解耦學習方法

Decoupling representation and classifier for long-tailed recognition

程式碼連結:https://github.com/facebookresearch/classifier-balancing

1. 主要貢獻

長尾分佈資料集是目前訓練模型的一個很大的挑戰,模型在這類資料集上通常會在 head-classes (即數量較多的類別)上overfitting,而在tail-classes(即數量較少的類別)上under-fitting。解決imbalanced的問題常用的方法有:1)re-sampling dataset;2)re-weighting loss function; 3)把head-classes的特徵遷移給tail-classes等。

該論文通過設定一系列的實驗,發現以下現象:

  • 把訓練過程解耦成了兩部分:1)representations learning (即特徵提取)和 2) classification 能夠有效提高模型在長尾分佈資料集上的效能
  • 作者發現以下兩種方法(在 representations learning 過程中同時優化訓練分類器)能提高效能:
    • 固定feature,然後使用class-balanced 取樣策略retrain分類器
    • 對分類器的權重加約束(正則)也可以提高效能
  • 以上方式用在像ResNet這些常用模型上也能在Long-Tailed (LT)資料集上取得不錯的效果

2. Representations Learning

2.1 Data Re-sampling

每個樣本被取樣的概率可以表示成如下:\(C\)表示類別數量, \(n_j\)表示第 j 類的樣本數,\(q\in\{0,1,0.5\}\)分別表示不同的取樣策略。

\[p_{j}=\frac{n_{j}^{q}}{\sum_{i=1}^{C} n_{i}^{q}} \tag{1} \]
  1. Instance-balanced (IB) sampling:這個就是最普通也是最常用的取樣策略,即每個樣本被取樣的概率均等,對應公式(1)中的 \(q=1\)
  2. Class-balanced (CB) sampling: 這個就是說每個類別被取樣的概率相等,比如我們總共有4類,每次取樣的batch包含64個樣本,那麼每個batch中一定包含4個類別,每個類別的數量都是16,只不過類別裡的樣本被取樣的概率就是相等的。具體的實現可以參考catalyst.data.sampler.BatchBalanceClassSampler
    [程式碼]。公式(1)中\(q=0\)時表示每個類別被取樣的概率相等
  3. Progressively-balanced sampling:這個其實就是將上面 Instance-和Class- balanced做了結合,即下式, \(t,T\)分別表示當前的epoch和總的epoch數。
\[p_{j}^{\mathrm{PB}}(t)=\left(1-\frac{t}{T}\right) p_{j}^{\mathrm{IB}}+\frac{t}{T} p_{j}^{\mathrm{CB}} \tag{2} \]
  1. Square-root sampling: 對應公式(1)中\(q=0.5\)

2.2 Loss re-weighting

比較常見的方法有 focal loss,或者給tail-classes賦予更高的權重等

3. Classification

上一節總結了常用的學習特徵的訓練方法,這一節總結常用的訓練分類器的方法。

  1. Classifier Re-training (cRT): 這個就是比較常規的做法,即把 feature representations固定住,然後使用class-balanced sampling 對classifier做finetune
  2. Nearest Class Mean classifier (NCM): 這個是非引數方法,即先使用訓練集計算出 \(C\) 個類別的中心 feature tensor,然後每次做預測的時候使用 cosine similarity或者 MSE loss計算出每個樣本離這些中心feature的距離,離誰更近就預測屬於哪一類,這類似於KNN演算法
  3. \(\tau\)-normalized classifier :我們知道在 TL 資料集上,模型在預測的時候會傾向於把樣本都預測成類別多的那一類,極端情況甚至全都預測成同一類。假設這一類是第 \(i\) 類,這個時候很有可能是因為最後預測器(即全連線層)的第 \(i\) 類的權重值遠大於其他類別的權重,所以一種解決辦法就是給分類器的權重加上正則項,公式如下,\(\tau\) 是一個超引數,當\(\tau=1\)時,下式就等價於普通的 L2正則。一般取值是在0到1之間。
\[\widetilde{w_{i}}=\frac{w_{i}}{\left\|w_{i}\right\|^{\tau}} \tag{3} \]
  1. Learnable weight scaling (LWS):公式3中的分母是依賴於權重值,當然我們也可以讓分母設定成一個可學習的引數 \(f_i\),它的初始值和公式3一致(如下式)。在優化 \(f_i\)的過程中,representations和classifier的引數都是固定住的。
\[\widetilde{w_{i}}=f_{i} * w_{i}, \text { where } f_{i}=\frac{1}{\left\|w_{i}\right\|^{\tau}} \tag{4} \]

4. 實驗

4.1 實驗設定

因為長尾分佈資料集中有的類別可能只有幾張圖片,有的可能有上千張圖片,所以之前常用的Acc並不能有效表達出模型效能的好壞,所以後面論文給出了不同類別的分類準確率

  • All: 所有類別的acc
  • Many-shot: 圖片數量大於100的類別的acc
  • Medium-shot:圖片數量在20到100之間的類別的acc
  • Few-shot:圖片數量小於20的類別的acc

4.2 Sampling Strategies & Decoupled Learning

從Figure1我們能看到一下幾個現象:

  1. 只看4個影象的 Joint (即backbone和classifier同時訓練)那一列,我們可以看到隨著取樣策略的改善(從Instance到Progressively-balanced),Medium和Few 類別以及整體(All)的accuracy是穩步提升的。但是對於 Many類別,它的accuracy在 Instance-balanced情況下是最高的,這個也符合預期,因為這個時候模型會更加側重於資料多的類別。所以實驗結果表明 對於Joint的訓練模式,資料取樣非常重要。
  2. 論文中給出了3個decoupled learning的方法,分別是 NCM, cRT和\(\tau\)-norm。上圖可以看到除了Many-shot,這三個方法在其他3個類別上都比Joint訓練模式表現更好
  3. 一個很有意思的實驗結果是,在3個解耦學習的方法上,IB 取樣策略訓練得到的模型反而表現最好。換句話說,如果我們使用解耦的訓練方式,我們可能不用太花心思在資料取樣上

Figure 2 (左) 給出了不同訓練模式下 classifier權重的norm值。圖中Class Index是按照類別包含的樣本數降序排列的,即class-0含有組多樣本。

  • 可以看到對於Joint模式,weight norm值是逐漸減少的,即class-0的norm值最大。顯然當這個norm值遠大於其他類別的norm值時,模型很可能會將所有樣本都只預測成class-0。
  • cRT, \(\tau\)-norm和LWS都有效提高了Medium和Few classes的weight norm。

Figure 2 (右) 給出了 \(\tau\)-norm方法\(\tau\)對結果的影響。可以看到增加τ的大小能明顯改善 Few classes的準確率,但是同時Many classes會對應減少。Medium和All 的準確率先增後降,而且後期降得特別厲害,所以τ值的選擇也比較重要。

4.3 實驗結果對比

作者在3個TL資料集上做了實驗,可以看到提升效果都比較明顯。

微信公眾號:AutoML機器學習
MARSGGBO原創
如有意合作或學術討論歡迎私戳聯絡~
郵箱:[email protected]