ICLR2020 | 解決長尾分佈的解耦學習方法
Decoupling representation and classifier for long-tailed recognition
程式碼連結:https://github.com/facebookresearch/classifier-balancing
1. 主要貢獻
長尾分佈資料集是目前訓練模型的一個很大的挑戰,模型在這類資料集上通常會在 head-classes (即數量較多的類別)上overfitting,而在tail-classes(即數量較少的類別)上under-fitting。解決imbalanced的問題常用的方法有:1)re-sampling dataset;2)re-weighting loss function; 3)把head-classes的特徵遷移給tail-classes等。
該論文通過設定一系列的實驗,發現以下現象:
- 把訓練過程解耦成了兩部分:1)representations learning (即特徵提取)和 2) classification 能夠有效提高模型在長尾分佈資料集上的效能
- 作者發現以下兩種方法(在 representations learning 過程中同時優化訓練分類器)能提高效能:
- 固定feature,然後使用class-balanced 取樣策略retrain分類器
- 對分類器的權重加約束(正則)也可以提高效能
- 以上方式用在像ResNet這些常用模型上也能在Long-Tailed (LT)資料集上取得不錯的效果
2. Representations Learning
2.1 Data Re-sampling
每個樣本被取樣的概率可以表示成如下:\(C\)表示類別數量, \(n_j\)表示第 j 類的樣本數,\(q\in\{0,1,0.5\}\)分別表示不同的取樣策略。
\[p_{j}=\frac{n_{j}^{q}}{\sum_{i=1}^{C} n_{i}^{q}} \tag{1} \]- Instance-balanced (IB) sampling:這個就是最普通也是最常用的取樣策略,即每個樣本被取樣的概率均等,對應公式(1)中的 \(q=1\)。
- Class-balanced (CB) sampling: 這個就是說每個類別被取樣的概率相等,比如我們總共有4類,每次取樣的batch包含64個樣本,那麼每個batch中一定包含4個類別,每個類別的數量都是16,只不過類別裡的樣本被取樣的概率就是相等的。具體的實現可以參考
catalyst.data.sampler.BatchBalanceClassSampler
- Progressively-balanced sampling:這個其實就是將上面 Instance-和Class- balanced做了結合,即下式, \(t,T\)分別表示當前的epoch和總的epoch數。
- Square-root sampling: 對應公式(1)中\(q=0.5\)
2.2 Loss re-weighting
比較常見的方法有 focal loss,或者給tail-classes賦予更高的權重等
3. Classification
上一節總結了常用的學習特徵的訓練方法,這一節總結常用的訓練分類器的方法。
- Classifier Re-training (cRT): 這個就是比較常規的做法,即把 feature representations固定住,然後使用class-balanced sampling 對classifier做finetune
- Nearest Class Mean classifier (NCM): 這個是非引數方法,即先使用訓練集計算出 \(C\) 個類別的中心 feature tensor,然後每次做預測的時候使用 cosine similarity或者 MSE loss計算出每個樣本離這些中心feature的距離,離誰更近就預測屬於哪一類,這類似於KNN演算法
- \(\tau\)-normalized classifier :我們知道在 TL 資料集上,模型在預測的時候會傾向於把樣本都預測成類別多的那一類,極端情況甚至全都預測成同一類。假設這一類是第 \(i\) 類,這個時候很有可能是因為最後預測器(即全連線層)的第 \(i\) 類的權重值遠大於其他類別的權重,所以一種解決辦法就是給分類器的權重加上正則項,公式如下,\(\tau\) 是一個超引數,當\(\tau=1\)時,下式就等價於普通的 L2正則。一般取值是在0到1之間。
- Learnable weight scaling (LWS):公式3中的分母是依賴於權重值,當然我們也可以讓分母設定成一個可學習的引數 \(f_i\),它的初始值和公式3一致(如下式)。在優化 \(f_i\)的過程中,representations和classifier的引數都是固定住的。
4. 實驗
4.1 實驗設定
因為長尾分佈資料集中有的類別可能只有幾張圖片,有的可能有上千張圖片,所以之前常用的Acc並不能有效表達出模型效能的好壞,所以後面論文給出了不同類別的分類準確率
- All: 所有類別的acc
- Many-shot: 圖片數量大於100的類別的acc
- Medium-shot:圖片數量在20到100之間的類別的acc
- Few-shot:圖片數量小於20的類別的acc
4.2 Sampling Strategies & Decoupled Learning
從Figure1我們能看到一下幾個現象:
- 只看4個影象的 Joint (即backbone和classifier同時訓練)那一列,我們可以看到隨著取樣策略的改善(從Instance到Progressively-balanced),Medium和Few 類別以及整體(All)的accuracy是穩步提升的。但是對於 Many類別,它的accuracy在 Instance-balanced情況下是最高的,這個也符合預期,因為這個時候模型會更加側重於資料多的類別。所以實驗結果表明 對於Joint的訓練模式,資料取樣非常重要。
- 論文中給出了3個decoupled learning的方法,分別是 NCM, cRT和\(\tau\)-norm。上圖可以看到除了Many-shot,這三個方法在其他3個類別上都比Joint訓練模式表現更好
- 一個很有意思的實驗結果是,在3個解耦學習的方法上,IB 取樣策略訓練得到的模型反而表現最好。換句話說,如果我們使用解耦的訓練方式,我們可能不用太花心思在資料取樣上。
Figure 2 (左) 給出了不同訓練模式下 classifier權重的norm值。圖中Class Index是按照類別包含的樣本數降序排列的,即class-0含有組多樣本。
- 可以看到對於Joint模式,weight norm值是逐漸減少的,即class-0的norm值最大。顯然當這個norm值遠大於其他類別的norm值時,模型很可能會將所有樣本都只預測成class-0。
- cRT, \(\tau\)-norm和LWS都有效提高了Medium和Few classes的weight norm。
Figure 2 (右) 給出了 \(\tau\)-norm方法\(\tau\)對結果的影響。可以看到增加τ的大小能明顯改善 Few classes的準確率,但是同時Many classes會對應減少。Medium和All 的準確率先增後降,而且後期降得特別厲害,所以τ值的選擇也比較重要。
4.3 實驗結果對比
作者在3個TL資料集上做了實驗,可以看到提升效果都比較明顯。