1. 程式人生 > 其它 >“半監督”異常檢測方法GANomaly

“半監督”異常檢測方法GANomaly

原文標題:GANomaly: Semi-Supervised Anomaly Detection via Adversarial Training

原文連結:

背景介紹

異常檢測是計算機視覺領域一個比較經典的問題,它旨在區分正常樣本(下文稱為OK樣本)和非正常樣本(下文稱為NG樣本)。乍一看,像是普通的二分類問題。其實不然,異常檢測有一個內在的屬性:樣本極其不平衡,即OK樣本非常多,NG樣本非常少。極端情況,訓練階段見不到任何NG樣本,該問題就變成了單分類問題了(本文也將這種只有OK樣本而沒有NG樣本參與訓練的情況稱為“半監督”,筆者認為是不妥的)。本文提出的GANomaly方法,就是針對這種極端情況的。

由於異常檢測問題中NG樣本通常比較少,直接學習能區分NG樣本的模型是很困難的。既然NG樣本不可靠,那大家自然會想到採取相反的思路,學習能區分OK樣本的模型就好,只要跟OK長得不像的就認為是NG的。自編碼器(Autoencoder)是異常檢測中比較經典的一種方法。它的解決思路是採用儘可能多的OK樣本去學習一個自編碼模型,由於該模型見過足夠多的OK樣本,因此它能夠很好地將OK樣本重建出來,而NG樣本它是沒見過的,因此它沒法很好地重建出來。推理階段,通過輸入圖片的重建誤差,就可以區分出OK和NG樣本了。但是,該方法非常容易受噪聲影響,需要在自編碼器上加各種約束,才能得到一個可用的異常檢測模型。

主要思想

如上圖所示,不同於一般的基於自編碼器的方法,本文采用的是一個編碼器(Encoder1)-解碼器(Decoder)-編碼器(Encoder2)的網路結構,同時學習“原圖->重建圖”和“原圖的編碼->重建圖的編碼”兩個對映關係。該方法不僅對生成的圖片外觀(圖片->圖片)做了的約束,也對圖片內容(圖片編碼->圖片編碼)做了約束。另外,該方法還引入了生成對抗網路(GAN)中的對抗訓練思想。這裡,作者將Encoder1-Decoder-Encoder2當成生成網路G-Net,又定義了一個判別網路D-Net,通過交替訓練生成網路和對抗網路,最終學到一個比較好的生成網路。

推理階段,該方法也不同於一般的基於自編碼器的異常檢測方法。最後用於推斷異常的不是原圖和重建圖的差異,而是第一部分編碼器產生的隱空間特徵(原圖的編碼)和第二部分編碼器產生的隱空間特徵(重建圖的編碼)的差異。這種方法更關注圖片實質內容的差異,對圖片中的微小變化不敏感,因而能解決自編碼器中易受噪聲影響的問題,魯棒性更好。

筆者認為本文的主要貢獻在於提出了這個Encoder1-Decoder-Encoder2的結構,D-Net只是錦上添花的。因為即便沒有D-Net和對抗訓練的思想,好好調引數該方法也可以work。

網路結構

本文網路結構包含三個子網路。

第一個子網路是一個常規的碗形的自編碼器,它的作用是用於重建輸入的OK影象。該自編碼器結構的設計參考了DCGAN,具體而言,該自編碼器的解碼器部分(Decoder)和DCGAN的生成網路幾乎是一樣的,即從一個n維的向量(bottleneck1)對映到一張3通道的圖片,如下圖所示。該自編碼器的編碼器部分(Encoder1)則是編碼器的逆過程,即從一張3通道的圖片對映到一個n維的向量。

第二個子網路是一個編碼網路(Encoder2),它的作用是將第一個子網路重建出來的圖片再壓縮為一個n維的向量(bottleneck2)。雖然Encoder2採用的結構和Encoder1是一樣的,但它們的引數顯然是不一樣的。這麼一個重複的結構看起來沒有什麼了不起的,但筆者認為該結構是本文思想中最為核心的地方,它摒棄了絕大部分基於自編碼器的異常檢測方法常用的通過對比原圖和重建圖的差異來推斷異常的方式,採用了一種新的通過對比原圖和重建圖在高一層抽象空間中的差異來推斷異常的方式,而這一層額外的抽象可以使其大大提高抗噪聲干擾的能力,學到更加魯棒的異常檢測模型。

文章中第一個子網路和第二個子網路共同構成了生成對抗網路中的生成網路(G-Net),聽起來有點費解。其實可以換個思路想,第一個子網路就是一箇中規中矩的生成網路,第二個子網路只是它的一個約束條件而已。

第三個子網路是一個判別網路(D-Net),它的作用是用於區分原圖和重建圖(G-Net生成的圖片),即要將原圖判別為真,將重建圖判別為假。它的結構和第一個子網路的解碼網路是一樣的。D-Net的引入,是為了引入對抗訓練思想,旨在學到更好的G-Net。

綜上,該文章設計的網路結構事實上比較簡單,就是一個Encoder和一個Decoder,只是通過不同的組合,生成了三部分的子網路。接下來將介紹每部分子網路採用的損失函式。

損失函式

本文包含三個子網路,每個子網路對應一個損失函式。由於文章中寫的損失函式和作者公佈的程式碼中的損失函式有些出入,筆者認為程式碼中的損失函式更為合理,因此下文介紹的都是程式碼中的損失函式。

第一個子網路的損失是自編碼器的重建損失,這裡借鑑了pix2pix文章中生成網路的損失,採用的是L1損失,而不是L2損失。因為採用L2損失生成的影象通常比採用L1生成的影象要模糊。

第二個子網路的損失是編碼網路的損失,這裡需要比對的是原圖和重建圖在高一層抽象空間中的差異,即兩個bottleneck(上文中的bottleneck1和bottleneck2)間的差異,採用的是L2損失。

第三個子網路的損失是常規的GAN中判別網路的損失,這裡採用的是二分類的交叉熵損失。

正常來說,採用第一個子網路的生成損失和第三個子網路的判別損失就能生成比較不錯的圖片了,但是這篇文章主要解決的是異常檢測問題。異常是圖片集的特性,採用畫素級的損失(原圖和重建圖的差異)來推斷是不夠合理的,因而引入了第二個子網路的編碼損失,文章中最後用於推斷的也是該損失。

訓練

本文采用的訓練策略和常規的GAN一樣的,即交替地優化D-Net和G-Net。

優化D-Net時,採用的損失為上述第三個子網路的損失,即:

這裡的輸入。雖然這裡的需要通過G-Net來生成,但是訓練D-Net時,G-Net的引數是固定的。

優化G-Net時,採用的損失比較複雜:

主體損失為重建損失,編碼損失為重建損失的一個約束,對抗損失則是用來和D-Net博弈。需要注意的一點是,這裡的對抗損失的輸入物件和優化D-Net時的輸入物件是不一樣的,這裡的,這和常規GAN的訓練是一致的。

推斷

前面提到,本文采用的推斷方式和一般的基於自編碼器的異常檢測方法是不一樣的。這裡推斷以來的不是重建損失,而是編碼損失。具體而言,網路訓練收斂以後,我們可以計算得到所有OK樣本中的值,選取其中最大的作為判別閾值。推斷時,給定一張圖片,我們可以利用學好的網路,計算其值,如果它小於判別閾值則判斷為OK樣本(正常樣本),大於則判斷為NG樣本(異常樣本)。

實驗

要做基於GANomaly的異常檢測實驗,需要準備大量的OK樣本和少量的NG樣本。找不到合適的資料集怎麼辦?很簡單,隨便找個開源的分類資料集,將其中一個類別的樣本當作異常類別,其他所有類別的樣本當作正常樣本即可,文章中的實驗就是這麼幹的。具體試驗結果如下:

反正在效果上,GANomaly是超過了之前兩種代表性的方法。此外,作者還做了效能對比的實驗。事實上前面已經介紹了GANomaly的推斷方法,就是一個簡單的前向傳播和一個對比閾值的過程,因此速度非常快。具體結果如下:

可以看出,計算效能上,GANomaly表現也是非常不錯的。

總結

雖然異常檢測在資料探勘領域很早就有人做了,但是計算機視覺領域的相關研究還相對較少。另外,GAN這幾年非常火,GAN到底能不能做異常檢測,還沒有太多人嘗試過。本文算是一個比較成功地將GAN用到異常檢測的例子。