通俗|令人拍案叫絕的Wasserstein GAN 及程式碼(WGAN兩篇論文的中文詳細介紹)
今天第一篇為WGAN理論深入介紹。
在GAN的相關研究如火如荼甚至可以說是氾濫的今天,一篇新鮮出爐的arXiv論文Wassertein GAN 卻在 Reddit 的 Machine Learning 頻道火了,連Goodfellow都在帖子裡和大家熱烈討論,這篇論文究竟有什麼了不得的地方呢?
要知道自從2014年Ian Goodfellow提出以來,GAN就存在著訓練困難、生成器和判別器的loss無法指示訓練程序、生成樣本缺乏多樣性等問題。從那時起,很多論文都在嘗試解決,但是效果不盡人意,比如最有名的一個改進DCGAN依靠的是對判別器和生成器的架構進行實驗列舉,最終找到一組比較好的網路架構設定,但是實際上是治標不治本,沒有徹底解決問題。而今天的主角Wasserstein GAN(下面簡稱WGAN)成功地做到了以下爆炸性的幾點:
-
徹底解決GAN訓練不穩定的問題,不再需要小心平衡生成器和判別器的訓練程度
-
基本解決了collapse mode的問題,確保了生成樣本的多樣性
-
訓練過程中終於有一個像交叉熵、準確率這樣的數值來指示訓練的程序,這個數值越小代表GAN訓練得越好,代表生成器產生的影象質量越高(如題圖所示)
-
以上一切好處不需要精心設計的網路架構,最簡單的多層全連線網路就可以做到
那以上好處來自哪裡?這就是令人拍案叫絕的部分了——實際上作者整整花了兩篇論文,在第一篇《Towards Principled Methods for Training Generative Adversarial Networks》裡面推了一堆公式定理,從理論上分析了原始GAN的問題所在,從而針對性地給出了改進要點;在這第二篇《Wassertein GAN》裡面,又再從這個改進點出發推了一堆公式定理,最終給出了改進的演算法實現流程,而改進後相比原始GAN的演算法實現流程卻只改了四點:
-
判別器最後一層去掉sigmoid
-
生成器和判別器的loss不取log
-
每次更新判別器的引數之後把它們的絕對值截斷到不超過一個固定常數c
-
不要用基於動量的優化演算法(包括momentum和Adam),推薦RMSProp,SGD也行
演算法截圖如下:
改動是如此簡單,效果卻驚人地好,以至於Reddit上不少人在感嘆:就這樣?沒有別的了? 太簡單了吧!這些反應讓我想起了一個頗有年頭的雞湯段子,說是一個工程師在電機外殼上用粉筆劃了一條線排除了故障,要價一萬美元——畫一條線,1美元;知道在哪畫線,9999美元。上面這四點改進就是作者Martin Arjovsky劃的簡簡單單四條線,對於工程實現便已足夠,但是知道在哪劃線,背後卻是精巧的數學分析,而這也是本文想要整理的內容。
本文內容分為五個部分:
-
原始GAN究竟出了什麼問題?(此部分較長)
-
WGAN之前的一個過渡解決方案
-
Wasserstein距離的優越性質
-
從Wasserstein距離到WGAN
-
總結
理解原文的很多公式定理需要對測度論、 拓撲學等數學知識有所掌握,本文會從直觀的角度對每一個重要公式進行解讀,有時通過一些低維的例子幫助讀者理解數學背後的思想,所以不免會失於嚴謹,如有引喻不當之處,歡迎在評論中指出。
以下簡稱 Wassertein GAN 為“WGAN本作”,簡稱 Towards Principled Methods for Training Generative Adversarial Networks 為“WGAN前作”。
WGAN 原始碼實現:https://github.com/martinarjovsky/WassersteinGAN
第一部分:原始GAN究竟出了什麼問題?
回顧一下,原始GAN中判別器要最小化如下損失函式,儘可能把真實樣本分為正例,生成樣本分為負例:
其中 Pr 是真實樣本分佈,Pg 是由生成器產生的樣本分佈。對於生成器,Goodfellow 一開始提出來一個損失函式,後來又提出了一個改進的損失函式,分別是
後者在 WGAN 兩篇論文中稱為“the - log D alternative”或“the - log D trick”。WGAN 前作分別分析了這兩種形式的原始GAN各自的問題所在,下面分別說明。
第一種原始GAN形式的問題
一句話概括:判別器越好,生成器梯度消失越嚴重。WGAN 前作從兩個角度進行了論證,第一個角度是從生成器的等價損失函式切入的。
首先從公式1可以得到,在生成器 G 固定引數時最優的判別器 D 應該是什麼。對於一個具體的樣本,它可能來自真實分佈也可能來自生成分佈,它對公式1損失函式的貢獻是
令其關於D(x)的導數為0,得
化簡得最優判別器為:
這個結果從直觀上很容易理解,就是看一個樣本 x 來自真實分佈和生成分佈的可能性的相對比例。如果 Pr(x) = 0 且 Pg(x) ≠ 0,最優判別器就應該非常自信地給出概率0;如果 Pr(x) = Pg(x),說明該樣本是真是假的可能性剛好一半一半,此時最優判別器也應該給出概率0.5。
然而GAN訓練有一個trick,就是別把判別器訓練得太好,否則在實驗中生成器會完全學不動(loss降不下去),為了探究背後的原因,我們就可以看看在極端情況——判別器最優時,生成器的損失函式變成什麼。給公式2加上一個不依賴於生成器的項,使之變成
注意,最小化這個損失函式等價於最小化公式2,而且它剛好是判別器損失函式的反。代入最優判別器即公式4,再進行簡單的變換可以得到
變換成這個樣子是為了引入 Kullback–Leibler divergence(簡稱KL散度)和 Jensen-Shannon divergence(簡稱JS散度)這兩個重要的相似度衡量指標,後面的主角之一Wasserstein距離,就是要來吊打它們兩個的。所以接下來介紹這兩個重要的配角——KL散度和JS散度:
於是公式5就可以繼續寫成
到這裡讀者可以先喘一口氣,看看目前得到了什麼結論:根據原始GAN定義的判別器loss,我們可以得到最優判別器的形式;而在最優判別器的下,我們可以把原始GAN定義的生成器loss等價變換為最小化真實分佈Pr與生成分佈Pg之間的JS散度。我們越訓練判別器,它就越接近最優,最小化生成器的loss也就會越近似於最小化Pr和Pg之間的JS散度。
問題就出在這個JS散度上。我們會希望如果兩個分佈之間越接近它們的JS散度越小,我們通過優化JS散度就能將Pg“拉向”Pr,最終以假亂真。這個希望在兩個分佈有所重疊的時候是成立的,但是如果兩個分佈完全沒有重疊的部分,或者它們重疊的部分可忽略(下面解釋什麼叫可忽略),它們的JS散度是多少呢?
答案是log2,因為對於任意一個x只有四種可能:
第一種對計算JS散度無貢獻,第二種情況由於重疊部分可忽略所以貢獻也為0,第三種情況對公式7右邊第一個項的貢獻是
第四種情況與之類似,所以最終
換句話說,無論 Pr 跟 Pg 是遠在天邊,還是近在眼前,只要它們倆沒有一點重疊或者重疊部分可忽略,JS散度就固定是常數
log2,而這對於梯度下降方法意味著——梯度為0!此時對於最優判別器來說,生成器肯定是得不到一丁點梯度資訊的;即使對於接近最優的判別器來說,生成器也有很大機會面臨梯度消失的問題。
但是 Pr 與 Pg 不重疊或重疊部分可忽略的可能性有多大?不嚴謹的答案是:非常大。比較嚴謹的答案是:當 Pr 與 Pg 的支撐集(support)是高維空間中的低維流形(manifold)時,Pr 與 Pg 重疊部分測度(measure)為0的概率為1。
不用被奇怪的術語嚇得關掉頁面,雖然論文給出的是嚴格的數學表述,但是直觀上其實很容易理解。首先簡單介紹一下這幾個概念:
不用被奇怪的術語嚇得關掉頁面,雖然論文給出的是嚴格的數學表述,但是直觀上其實很容易理解。首先簡單介紹一下這幾個概念:
-
支撐集(support)其實就是函式的非零部分子集,比如ReLU函式的支撐集就是(0,+∞),一個概率分佈的支撐集就是所有概率密度非零部分的集合。
-
流形(manifold)是高維空間中曲線、曲面概念的拓廣,我們可以在低維上直觀理解這個概念,比如我們說三維空間中的一個曲面是一個二維流形,因為它的本質維度(intrinsic dimension)只有2,一個點在這個二維流形上移動只有兩個方向的自由度。同理,三維空間或者二維空間中的一條曲線都是一個一維流形。
-
測度(measure)是高維空間中長度、面積、體積概念的拓廣,可以理解為“超體積”。
回過頭來看第一句話,“當 Pr 與 Pg 的支撐集是高維空間中的低維流形時”,基本上是成立的。原因是GAN中的生成器一般是從某個低維(比如100維)的隨機分佈中取樣出一個編碼向量,再經過一個神經網路生成出一個高維樣本(比如64x64的圖片就有4096維)。當生成器的引數固定時,生成樣本的概率分佈雖然是定義在4096維的空間上,但它本身所有可能產生的變化已經被那個100維的隨機分佈限定了,其本質維度就是100,再考慮到神經網路帶來的對映降維,最終可能比100還小,所以生成樣本分佈的支撐集就在4096維空間中構成一個最多100維的低維流形,“撐不滿”整個高維空間。
“撐不滿”就會導致真實分佈與生成分佈難以“碰到面”,這很容易在二維空間中理解:一方面,二維平面中隨機取兩條曲線,它們之間剛好存在重疊線段的概率為0;另一方面,雖然它們很大可能會存在交叉點,但是相比於兩條曲線而言,交叉點比曲線低一個維度,長度(測度)為0,可忽略。三維空間中也是類似的,隨機取兩個曲面,它們之間最多就是比較有可能存在交叉線,但是交叉線比曲面低一個維度,面積(測度)是0,可忽略。從低維空間拓展到高維空間,就有了如下邏輯:因為一開始生成器隨機初始化,所以 Pg 幾乎不可能與Pr有什麼關聯,所以它們的支撐集之間的重疊部分要麼不存在,要麼就比 Pr 和 Pg 的最小維度還要低至少一個維度,故而測度為0。所謂“重疊部分測度為0”,就是上文所言“不重疊或者重疊部分可忽略”的意思。
我們就得到了WGAN前作中關於生成器梯度消失的第一個論證:在(近似)最優判別器下,最小化生成器的loss等價於最小化 Pr 與 Pg 之間的JS散度,而由於 Pr 與 Pg 幾乎不可能有不可忽略的重疊,所以無論它們相距多遠JS散度都是常數 log2,最終導致生成器的梯度(近似)為0,梯度消失。
接著作者寫了很多公式定理從第二個角度進行論證,但是背後的思想也可以直觀地解釋:
-
首先,Pr 與 Pg 之間幾乎不可能有不可忽略的重疊,所以無論它們之間的“縫隙”多狹小,都肯定存在一個最優分割曲面把它們隔開,最多就是在那些可忽略的重疊處隔不開而已。
-
由於判別器作為一個神經網路可以無限擬合這個分隔曲面,所以存在一個最優判別器,對幾乎所有真實樣本給出概率1,對幾乎所有生成樣本給出概率0,而那些隔不開的部分就是難以被最優判別器分類的樣本,但是它們的測度為0,可忽略。
-
最優判別器在真實分佈和生成分佈的支撐集上給出的概率都是常數(1和0),導致生成器的loss梯度為0,梯度消失。
有了這些理論分析,原始GAN不穩定的原因就徹底清楚了:判別器訓練得太好,生成器梯度消失,生成器loss降不下去;判別器訓練得不好,生成器梯度不準,四處亂跑。只有判別器訓練得不好不壞才行,但是這個火候又很難把握,甚至在同一輪訓練的前後不同階段這個火候都可能不一樣,所以GAN才那麼難訓練。
實驗輔證如下:
圖:WGAN前作Figure 2。先分別將DCGAN訓練1,20,25個epoch,然後固定生成器不動,判別器重新隨機初始化從頭開始訓練,對於第一種形式的生成器loss產生的梯度可以打印出其尺度的變化曲線,可以看到隨著判別器的訓練,生成器的梯度均迅速衰減。注意y軸是對數座標軸。
第二種原始GAN形式的問題
一句話概括:最小化第二種生成器loss函式,會等價於最小化一個不合理的距離衡量,導致兩個問題,一是梯度不穩定,二是collapse mode即多樣性不足。WGAN前作又是從兩個角度進行了論證,下面只說第一個角度,因為對於第二個角度我難以找到一個直觀的解釋方式,感興趣的讀者還是去看論文吧。
如前文所說,Ian Goodfellow提出的“- log D trick”是把生成器loss改成
上文推導已經得到在最優判別器D*下
我們可以把KL散度(注意下面是先g後r)變換成含D*的形式:
由公式3,9,10可得最小化目標的等價變形
注意上式最後兩項不依賴於生成器G,最終得到最小化公式3等價於最小化
這個等價最小化目標存在兩個嚴重的問題。第一是它同時要最小化生成分佈與真實分佈的KL散度,卻又要最大化兩者的JS散度,一個要拉近,一個卻要推遠!這在直觀上非常荒謬,在數值上則會導致梯度不穩定,這是後面那個JS散度項的毛病。
第二,即便是前面那個正常的KL散度項也有毛病。因為KL散度不是一個對稱的衡量,KL(Pg || Pr) 與 KL(Pr || Pg) 是有差別的。以前者為例
換言之,KL(Pg || Pr) 對於上面兩種錯誤的懲罰是不一樣的,第一種錯誤對應的是“生成器沒能生成真實的樣本”,懲罰微小;第二種錯誤對應的是“生成器生成了不真實的樣本” ,懲罰巨大。第一種錯誤對應的是缺乏多樣性,第二種錯誤對應的是缺乏準確性。這一放一打之下,生成器寧可多生成一些重複但是很“安全”的樣本,也不願意去生成多樣性的樣本,因為那樣一不小心就會產生第二種錯誤,得不償失。這種現象就是大家常說的collapse mode。
第一部分小結:在原始GAN的(近似)最優判別器下,第一種生成器loss面臨梯度消失問題,第二種生成器loss面臨優化目標荒謬、梯度不穩定、對多樣性與準確性懲罰不平衡導致mode collapse這幾個問題。
實驗輔證如下: