1. 程式人生 > >Incremental Learning of Object Detectors without Catastrophic Forgetting詳解

Incremental Learning of Object Detectors without Catastrophic Forgetting詳解

Incremental Learning of Object Detectors without Catastrophic Forgetting詳解

最近由於專案的需要在研究incremental learning在目標檢測方面的應用,剛好讀到了INRIA在2007年的一篇paper,採用蒸餾loss的方法來做incremental learning的,所以寫這篇部落格記錄下來。

概述

不懂什麼叫incremental learning或者是catastrophic forgetting的可以參考知乎這個連結,王乃巖介紹的非常完善,自己也學到了不少。
CNN用於目標檢測任務的缺陷——類別遺忘:假設CNN模型A為在一個物體檢測訓練集1上訓練得到的效能較好的檢測器,現在有另外一個訓練集2,其中物體類別與1不同,使用訓練集2在A的基礎上進行fine-tune得到模型B,模型B在訓練集2中的類別上可以達到比較好的檢測結果,但是在訓練集1中的類別上檢測效能就會大幅度下降;

本文目的:緩解CNN用於目標檢測任務的類別遺忘,在訓練集1中原始圖片不可得以及新圖片中不包含訓練集1中存在的類別的標註的情況下,在訓練集2上fine-tune模型A得到模型B,可以同時在訓練集1和2中的類別上獲得較好的檢測效能;

本文核心:在fine-tune模型A得到模型B的過程中提出一個新的損失函式,用於同時考慮網路在新的類別上的預測效能以及原始類別在新模型B和舊模型A上的響應差異,LOSS=新類別檢測LOSS+舊類別在模型A和模型B上的差異LOSS。

方法的核心:平衡新類別預測(即交叉熵損失)與新的蒸餾損失之間的相互作用的損失函式,其將原始和新網路的舊類別的響應之間的差異最小化。

網路結構

作者也提出:解決這個新增分類的問題可以再模型A上增加對新類別的預測分支,隨即初始化該分支後,用新類別資料fine-tune這個分支,但是這樣做會導致一個問題,此時得到的網路對原來N個類別的檢測效能會大幅下降。所以作者提出了一種新的loss,既能夠檢測出新的類,同時也能保證在舊的類的檢測準確率不會下降。網路結構如下:
在這裡插入圖片描述

Network A:It contains a frozen copy of the original detector。作用:1)檢測原始類別的bbox;2)蒸餾proposals並計算蒸餾loss;
Network B:用於對新增分類B的網路,結合模型A最終可以預測出新的類和舊的類;

作者指出:選擇fast-rcnn而非選擇faster-rcnn,因為faster-rcnn中有RPN層,其對類別有一定的敏感性,因為RPN可被訓練且共享卷積,,不利於最後蒸餾loss的計算,所以作者選基於edgeboxes的fast-rcnn,因為其類別對proposal不敏感。
在作者的這個fast-rcnn中,將vgg16替換為resnet50,並在最後一層stride!= 1的卷積層前加入了RoI pooling層,然後在接上剩下的卷積層和兩層FC連線每個類別的得分輸出和迴歸輸出,使用該主幹網路訓練用於檢測類別集合1的模型A。
loss_cls層評估分類代價。由真實分類u對應的概率決定:
L c l s L_{cls} =−log p u p_u
L c l s L_{cls} =−log⁡ p u p_u
loss_bbox評估檢測框定位代價。比較真實分類對應的預測引數 t u t_u 和真實平移縮放參數為 v v 的差別:
L l o c L_{loc} = Σ i = 1 4 Σ_{i=1}^4 g( t i u t_i^u v i v_i )

g為Smooth L1誤差,對outlier不敏感:
在這裡插入圖片描述
總代價為兩者加權和,如果分類為背景則不考慮定位代價:
L={Lcls+λLlocLclsu為前景u為背景
這個詳細的可以參考fast-rcnn原paper,這裡不詳說。

訓練方法

首先訓練一個fast-rcnn的網路結構使其能夠檢測原本的資料集 C a C_a ,這個網路結構記為A( C A C_A )。所以我們現在的目標是曾傑一個新的類資料集 C B C_B
我們對先前訓練得到的網路A( C A C_A )做兩份copies:一個凍結的網路通過蒸餾loss對原來的 C A C_A 進行檢測識別;另外一個B( C B C_B )被擴充用來檢測新的分類 C B C_B (在元資料中未出現或未被標註)。我們建立一個新的FC層用來只對新的分類檢測,然後將其output和原來的的輸出做concat,即:根據新增加的類別數對網路A進行擴充套件,即增加全連線層的輸出個數,得到初始化的Network B網路。新的層是採用和先前的網路A一樣的初始化方式進行隨機初始化的。現在我們的目標就是:訓練一個網路能夠僅僅使用新的資料,最後能夠識別出新增分類和舊分類的網路。
作者指出蒸餾loss是為了“keeping all the answers of the network the same or as close as possible”。如果我們訓練網路B( C B C_B )不做蒸餾的話,這個網路的效能在原來的類上將會急劇下降,這就是所謂的catastrophic forgetting(災難性遺忘)。Even if no object is detected by A(CA), the unnormalized logits (softmax input) carry enough information to “distill” the knowledge of the old classes from A( C A C_A ) to B( C B C_B ).

細節

對於每一個訓練圖片,隨機從128個RoI中選取64的背景得分最低的RoI,並分別得到其通過模型A後在舊的類別集合上的得分和迴歸目標,同樣得到其在通過模型B後在舊的類別集合上的得分和迴歸目標。
Loss函式包括logits(即softmax的input)和迴歸的outputs:
在這裡插入圖片描述
N:用於蒸餾的RoIs的數量(文章選的64)
| C A C_A |:原始資料的類別數
t A t_A :bounding box regression outputs
蒸餾logits不使用任何的smoothing,因為大多數的proposals已經經歷了smoothing在分數的分佈上。在我們的試驗中,在初始階段,新的和舊的網路的引數基本一致,所以沒必要smoothing來穩定其訓練。
所以總的損失函式定義如下:
在這裡插入圖片描述

取樣策略

作者實驗發現:選擇非背景proposal進行蒸餾學習相比隨機選擇proposal進行蒸餾學習得到的網路更檢測效能更好。
其他的作者做的一些實驗本文就不在這裡敘述了。隨後獻上paper和作者的程式碼