1. 程式人生 > 其它 >[論文分享] Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima

[論文分享] Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima

我又來給大家分享PAPER了!!!
今天給大家分享的這篇論文是NIPS’ 2021的一篇Few-Shot增量學習(FSCIL)文章,這篇文章通過固定backbone和prototype得到一個簡單的baseline,發現這個baseline已經可以打敗當前IL和IFSL的很多SOTA方法,基於此通過借鑑robust optimize的方法,提出了在base training訓練時通過flat local minima來對後面的session進行fine-tune novel classes,解決災難性遺忘問題。

No. content
PAPER {NIPS' 2021} Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima[1]
URL 論文地址
CODE 程式碼地址

1.1 Motivation

  • 不同於現有方法在學習新任務時嘗試克服災難性遺忘問題,這篇文章提出在訓練base classes時就提出策略來解決這個問題。
  • 作者提出找到base training function的flat local minima,最小值附近loss小,作者認為base classes分離地更好。(直覺上,flat local minima會比sharp的泛化效果更好,參閱下圖[2])

1.2 Contribution

  • 作者發現一個簡單的baseline model,只要在base classes上訓練,不在new tasks上進行適應,就超過了現有的SOTA方法,說明災難性遺忘問題非常嚴重。
  • 作者提出在primitive stage來解決災難性遺忘問題,通過在base classes上訓練時找到flat minima region並在該region內學習新任務,模型能夠更好地克服遺忘問題。

1.3 A Simple Baseline

作者提出了一個簡單的baseline,模型只在base classes上進行訓練,在後續的session上直接進行推理。

Training(t=1)
在session 1上對特徵提取器進行訓練,並使用一個全連線層作為分類器,使用CE Loss作為損失函式,從session 2(\(t\geq2\))開始將特徵提取器固定住,不使用novel classes進行任何fine-tune操作。
Inference(test)


使用均值方式獲得每個類的prototype,然後通過歐氏距離\(d(·,·)\)採用最近鄰方式進行分類。分類器的公式如下:

其中\(p_c\)表示類別\(c\)的prototype,\(N_c\)表示類別\(c\)的訓練圖片數量。同時作者將\(C^T\)中所有類的prototypes儲存下來用於後續的evaluation。
作者表示通過這種儲存old prototype的方式就打敗了現有的SOTA方法,證明了災難性遺忘非常嚴重。

1.4 Method

核心想法就是在base training的過程中找到函式的flat local minima \(\theta^*\),並在後續的few-shot session中在flat region進行fine-tune,這樣可以最大限度地保證在novel classes上進行fine-tune時避免遺忘知識。在後續增量few-shot sessions(\(t\geq2\))中,在這個flat region進行fine-tune模型引數來學習new classes

1.4.1 尋找Base Training的flat local minima

Definition 1(\(b\)-Flat Local Minima) Given a real-valued objective function \(\mathcal{L}(z; θ)\), for any \(b > 0\), \(\theta^*\) is a b-flat local minima of \(\mathcal{L}(z; θ)\), if the following conditions are satisfied.

為了找到base training function的近似flat local minima,作者提出新增一些隨機噪聲到模型引數,噪聲可以被多次新增以獲得相似但不同的loss function,直覺上,flat local minima附近的引數向量有小的函式值。
假設模型的引數\(\theta=\{\phi,\psi\}\)\(\phi\)表示特徵提取網路的引數,\(\psi\)表示分類器的引數。\(z\)表示一個有類標訓練樣本,損失函式\(\mathcal{L}:\ \mathbb{R}^{d_z} \rightarrow \mathbb{R}\)。我們的目標就是最小化期望損失函式

\(P(z)\)是資料分佈\(P(\epsilon)\)是噪聲分佈,\(z\)\(\epsilon\)是相互獨立的。
因此最小化期望損失是不可能的,所以這裡我們最小化他的近似,empirical loss,

\(\epsilon_i\)\(P(\epsilon)中的噪聲樣本\)\(M\)是取樣次數。這個loss的前半部分是為了找到flat region,它的特徵提取網路引數\(\phi\)可以很好地區分base classes。第二部分是通過MSE Loss的設計為了讓prototype儘量保持不變, 避免模型遺忘過去的知識。

1.4.2 在Flat Region內進行IFSL

作者認為雖然flat region很小,但對於few-shot的少量樣本來說,足夠對模型進行迭代更新。

通過歐氏距離使用基於度量的分類演算法來fine-tune模型引數。

1.4.3 收斂性分析

我們的目標是找到一個flat region使模型效果較好。然後,通過最小化期望損失(噪聲\(\epsilon\)和資料\(z\)的聯合分佈)。為了近似這個期望損失,我們在每次迭代中多次從\(P(\epsilon)\)取樣,並使用隨機梯度下降(SGD)優化目標函式。後面是相關的理論證明,感興趣的可以自行閱讀分析。

【參考文獻】
[1] Shi G, Chen J, Zhang W, et al. Overcoming Catastrophic Forgetting in Incremental Few-Shot Learning by Finding Flat Minima[J]. Advances in Neural Information Processing Systems, 2021, 34.
[2] He H, Huang G, Yuan Y. Asymmetric valleys: Beyond sharp and flat local minima[J]. arXiv preprint arXiv:1902.00744, 2019.