BMVC2018—Adversarial Learning for Semi-Supervised Semantic Segmentation
【筆記】Adversarial Learning for Semi-Supervised Semantic Segmentation
- Author:Wei-Chih Hung
- Organization:University of California Merced(加州大學默塞德分校)
- From: BMVC2018
- Citations:176
- Code:https://github.com/hfslyc/AdvSemiSeg
Abstract
本文使用對抗網路進行半監督語義分割。以往的判別器對整張影象判斷真假,本文設計的判別器,能夠在空間解析度上區分預測概率圖和真實分割的分佈。在半監督學習中,通過對未標註的影象的預測結果產生置信區域,提供了一定的監督訊號
2 Related work
Semantic segmentation
在弱監督學習中,分割網路不是由畫素級別標註的監督訊號中學習的,而是用類似影象級別的標籤、包圍框這種監督訊號,但是很難從弱監督訊號推斷分割物體的邊界資訊,所以人們常將全標註資料和弱標註資料一起訓練,即半監督學習
而本文使用全標註資料與無標註資料訓練網路,利用全卷積判別器作為無標註資料的監督訊號,實現半監督學習。
Generative adversarial networks
[29] 使用對抗網路做語義分割,但是與基線相比效能提升不大
[40] 用GAN做半監督語義分割,根據密集標籤生成的樣本不夠真實
3 Algorithm Overview
模型包括兩個模組
- 分割網路
\(H\times W\times 3\) - > 分割網路 -> \(H\times W\times C\) (其中\(C\)是分割類別的數量)
- 判別網路
輸入是類概率圖,類概率圖是由分割網路 或是 從ground truth label 得到的
輸出是\(H\times W\times 1\)的概率圖,畫素值\(p=1\)表示來自ground truth label,畫素值\(p=0\)表示來自分割網路
訓練過程:訓練時同時使用了帶有標註的影象和未標註的影象
- 當使用標註影象時,分割網路同時受到基於ground truth label的標準交叉熵損失\(L_{ce}\)
- 當使用未標註影象時,用分割網路得到初步分割結果,然後將初步分割結果送入判別網路得到置信度圖,將置信度圖作為監督訊號,用自學習的方法通過\(L_{semi}\)訓練分割網路
4 Semi-Supervised Training with Adversarial Network
4.1 Network Architecture
Segmentation network:使用帶有ResNet-101的DeeLab-v2模型,模型在ImageNet資料集和MSCOCO資料集中進行預訓練。去掉了多尺度融合,去掉最後一個分類層,將最後兩個卷積層的stride從2變為1,輸出的特徵圖解析度為原圖大小的1/8,為了擴大感受野,在conv4和conv5分別採用stride2和4的空洞卷積,在最後一層使用ASPP(Atrous Spatial Pyramid Pooling),最後上取樣至原圖大小,經過softmax後輸出
Discriminator network:有5個卷積層,卷積核\(4\times 4\),通道數為\(\{64, 128, 256, 512, 1\}\),stride為2,每個卷積層後(除了最後一層)是引數為0.2的Leaky-ReLU,最後一層使用上取樣,將輸出影象變為輸入影象大小,沒有使用BN層,因為BN只在batch size足夠大時才有效
4.2 Loss Function
輸入影象\(X_{n}\),大小為\(H\times W\times 3\)
分割網路\(S(\cdot)\),得到分割預測概率圖\(S(X_{n})\),大小\(H\times W\times C\),其中\(C\)是類數量
判別網路\(D(\cdot)\),其輸入是概率圖\(H\times W\times C\)(有兩種情況:1. 分割預測概率圖\(S(X_{n}\)),2. one-hot編碼的ground truth向量\(Y_{n}\)),其輸出是置信度圖\(H\times W\times 1\)
Discriminator network
\[L_{D} = -\sum_{h,w}(1-y_{n})\log (1-D(S(X_{n}))^{(h,w)}) + y_{n}\log(D(Y_{n})^{(h,w)}) \]當\(y_{n}=0\)時樣本來自分割網路,當\(y_{n}=1\)時樣本來自ground truth label。其中\(D(S(X_{n}))^{(h,w)}\)是\(X_{n}\)的置信度圖,為了將離散ground truth label轉換為\(C\)通道概率圖,使用one-hot編碼方法,每個通道表示相應類別的掩模
判別網路存在的一個問題是:ground truth的輸入是one-hot概率圖(值為0或1),而分割預測概率圖是概率(值為0~1),可能使判別網路學習到one-hot編碼的即為真
但是本文沒有這個問題,可能因為判別器的預測結果是空間置信度(即對每個畫素進行判別),增加了判別器的訓練難度
Segmentation network
\[L_{seg} = L_{ce} + \lambda_{adv} L_{adv} + \lambda_{semi}L_{semi} \]帶有標註的資料:輸入影象\(X_{n}\),其one-hot編碼的ground truth是\(Y_{n}\),分割結果\(S(X_{n})\),其交叉熵損失為:
\[L_{ce} = -\sum_{h,w}\sum_{c\in C} Y_{n}^{(h,w,c)}\log (S(X_{n})^{(h,w,c)}) \]對抗損失,愚弄判別器,通過將分割網路得到的預測概率圖視為真實分佈:
\[L_{adv} = -\sum_{h.w}\log(D(S(X_{n}))^{(h,w)}) \]未標註的資料:由於未標註的資料沒有ground truth,所以沒有使用\(L_{ce}\),但是依然使用了\(L_{adv}\)項,我們發現應該選擇比標註資料更小的\(\lambda_{adv}\)值,因為在沒有交叉熵損失的情況下,對抗損失為了擬合ground truth分佈可能會對預測進行過度修正
- 判別網路生成置信圖\(D(S(X_{n}))\),用閾值對置信圖進行二值化,突出信任區域
- 自學習的one-hot編碼ground truth \(\hat{Y}_{n}\)計算方法:\(\hat{Y}_{n}^{(h,w,c^{*})} = 1 \; if \; c^{*} = argmax_{c} S(X_{n}^{(h,w,c)})\)
其中\(I(\cdot)\)是指示函式,\(T_{semi}\)是閾值,設定為0.1到0.3之間,注意在訓練時,我們把自學習目標\(\hat{Y}_{n}\)和指示函式的值都看成常數,所以\(L_{semi}\)可以視為掩模空間交叉熵損失
5 Experimental Results
對於標註資料:\(\lambda_{adv} = 0.01\),\(\lambda_{semi}=0.1\);對於未標註資料:\(\lambda_{adv}=0.001\),\(\lambda_{semi}=0.2\)
先用標註資料訓練迭代5000次之後再交替訓練標註資料和未標註資料,在每次迭代中,只是用標註資料訓練判別器
Evaluation datasets
PASCAL VOC 2012:包括20中常見物品的影象分割,額外加入了SBD資料集中額外標註的影象,得到共10582張訓練集影象,1449張驗證集影象。使用大小\(321\times 321\)的隨機縮放和裁剪操作,訓練迭代共20K次,batch size為10
Cityscapes:包括50個車載行駛視訊,19類,標註資料的訓練集2975張,驗證集500張,測試集1525張,輸入影象大小為\(512\times 512\),沒有任何縮放裁剪,訓練迭代共40K次,batch size為2
PASCAL VOC 2012
隨機抽取1/8、1/4、1/2作為標註資料集,其餘的作為未標註資料集
baseline模型是沒有多出度融合的DeepLab-v2模型
Cityscapes
Comparisons with state-of-the-art methods
這是和第一個提出用GAN思想做語義分割那篇文章進行比較,本文與其不同點有二:1)本文是通用的網路結構,而[29]為不同資料集制定了對應的網路結構;2)本文沒有額外使用RGB通道
與[33]、[40]的方法比較,本文的模型在PASCAL VOC2012訓練集(1,464)中訓練,使用SBD資料集作為未標註資料
Hyper-parameter analysis
當\(T_{semi}=0\)時表示相信每個畫素