1. 程式人生 > 實用技巧 >一文看懂GAN

一文看懂GAN

生成對抗網路(Generative Adversarial Network,GAN)自從2014年由Ian Goodfellow提出以來,一直受到了廣泛的關注和研究,在短短几年時間內獲得了快速的發展,並在許多應用場景中取得了顯著的成果。

常用的模型大多可以分為兩類,生成模型(Generative Model)和判別模型(Discriminative Model)。前者的輸出是“看起來很逼真”的虛假資料,例如合成人臉圖片、撰寫新聞報道,應當和真實的資料儘可能相似;相比之下,後者一般實現一個從複雜結構資料到簡單判別結果的對映(Mapping),例如判斷圖片是貓還是狗、文字所表達的情緒為積極或者消極

根據以上的定義,GAN顯然屬於生成模型,因為我們希望通過GAN生成一些以假亂真的合成數據。進一步而言,生成模型可以分為無條件(Unconditional)和有條件(Conditional)兩類,其中前者的生成結果完全隨機,後者的生成結果則是可控的,例如生成男性或女性的人臉圖片、撰寫符合某個話題的新聞報道。

01 模型優化

在正式討論GAN的原理之前,我們先介紹一些關於模型優化的基本知識。如果你不是小白,則完全可以跳過這一部分

模型(Model)就好比一臺機器,通常用於完成某一項任務,有輸入(Input)也有輸出(Output),例如在貓狗圖片二分類這一任務中,輸入為各種各樣的貓狗圖片,輸出為貓或狗這兩個選擇之一

模型會包含一些可調引數(Trainable Parameters),這些引數的取值都是可變的,就好比一臺電壓力鍋,可以設定不同的壓力和時長。對於同樣的食材,當壓力和時長設定得不一樣時,煮出來的結果也會不一樣。類似的,對於同樣的輸入,當模型的引數取不同的值時,得到的輸出也將不同

一般而言,我們會通過某種初始化方法(Initializer)為模型的每個引數設定一個初始值,例如隨機初始化。回到貓狗二分類的問題上來,初始化之後的模型,肯定是無法對於每張輸入圖片,都輸出正確的分類的結果

為了讓模型學會貓狗分類的能力,我們會準備一些標註資料(Labeled Data),例如1萬張貓狗圖片,以及每張圖片所對應的標註(Label),比如用0代表貓、用1代表狗,然後通過這些標註資料來訓練(train)模型。這種每個樣本都提供了正確答案(Ground Truth)的訓練模式,稱為有監督學習(Supervised Learning)

具體如何訓練呢?我們可以每次拿4張圖片,輸入模型並得到相應的輸出,這些輸出結果中有對的也有錯的,我們希望它們儘可能地接近正確答案。可以使用某種損失函式(Loss Function)來衡量兩者之間的差距,例如二分類問題中最常用的二元交叉熵(Binary Cross Entropy),損失函式越小,則表明模型的輸出越接近正確答案

於是我們對模型說,你看啊,還差那麼多,趕緊把引數調一下!具體怎麼調呢?唯一的原則就是,每個引數調整之後,都應當使當前的損失函式往減小的方向變化,這一點在數學上可以通過求導來實現。如此一來,我們便完成了一次模型優化(Optimization),也稱為一次迭代(Iteration),我們的模型變得更聰明瞭,對於同樣的輸入,輸出結果將更接近正確答案

上面我們是每次拿4張圖片來訓練,如果每次只用一張圖片,那麼優點是迭代一次的時間將縮短,但缺點是訓練可能不穩定,畢竟在做決定時,我們一般會同時聽取多方意見。另一個極端是,每次都用全部圖片來訓練,這樣得到的引數調整方案會更加靠譜,但迭代一次所需的時間也會大大增加

因此,通常我們會每次拿一批(Batch)資料來訓練,一批資料中所包含的圖片數量稱為批大小(Batch Size)。在不同的優化任務中,批大小如何設定依賴於經驗,但一般會設定為2的冪,例如2、4、8、16、32、64等。假設一共有1萬張圖片,批大小設為16,那麼經過10000/16=625次迭代後,所有的資料都過了一遍,我們稱訓練了一輪(Epoch)

可想而知,隨著訓練的進行,我們的模型將變得越來越聰明,那麼應當如何評估模型的效能呢?一種常用的做法是將標註資料劃分為訓練集(Train Set)和測試集(Test Set),僅用訓練集來優化模型,然後在測試集上評估模型,就好比一共有100套卷子,80套用來練習,20套用來考試。評估的時候會用到一些評估指標(Evaluation Metric),例如分類問題中常用的正確率(Accuracy)、準確率(Precision)和召回率(Recall)等

還有一個待解決的問題,便是何時停止訓練,我們可以只訓練20輪,當然也可以訓練100輪。一般而言,如果訓練輪數過少,則模型可能學得不夠,即還有優化的空間;如果訓練輪數過多,則模型可能學得過頭了,具體表現為在訓練集上效能很好,但在測試集上效能很差。至於如何找到那個恰到好處的點,則需要了解欠擬合、過擬合、模型複雜度、正則化等內容,這裡就不再展開介紹了。

02 圖片的表示

圖片是由一個個畫素點組成的,一張高度為H、寬度為W的圖片,共包括H*W個畫素點。如果是RGB彩色圖,則每個畫素點包括三個顏色通道,即紅(Red)、綠(Green)、藍(Blue),每個顏色通道的取值都是0~255之間的整數,這樣一來,就一共有256*256*256=16,777,216種不同的畫素值,也就是所謂的“顏色”

GAN在CV界所取得的進展和成果遠遠多於自然語言處理(Natural Language Processing,NLP),一個主要原因就是圖片的表示是“連續”的,即當畫素值發生微小變化時,視覺上並不會察覺出明顯的區別。相比之下,NLP任務中一般會將字或詞作為最基礎的語義單元,而字和詞的表示是“離散”的,即我們很難統一規定,當一個字或詞發生微小變化時,對應的語義是哪一個其他的字或詞。字或詞的這種“離散跳變性”,無疑加大了GAN訓練時的困難和不穩定性,從而導致GAN在NLP領域中的發展和應用相對較少。


03 生成和對抗

在上面的貓狗圖片二分類例子中,模型需要實現從圖片到01二分類結果之間的對映,一般情況下只需要一個模組(Module)即可。相比之下,GAN包括兩個模組,生成器(Generator,G)和判別器(Discriminator,D)

G的任務是隨機生成以假亂真的合成數據。為了滿足隨機性,通常我們會使用多個隨機數作為G的輸入,例如100個從標準正態分佈(Random Normal Distribution)中隨機取樣得到的隨機數,記作隨機噪音(Random Noise)z,輸出則是和真實圖片相同解析度的圖片

D的任務是區分真假,即判斷一張圖片到底是真實圖片,還是G合成的虛假圖片。因此,D的輸入是圖片,輸出是一個分數D(·),分值越高表示輸入圖片越真實。理想情況下,D應當對於所有真實圖片都輸出高分,對於所有虛假圖片都輸出低分,如此一來,便可以完美實現判別真假的目標

回到我們之前介紹的模型優化上來,在具體實現上,G和D都可以通過神經網路(Neural Network)來實現,並且都包含大量的可調引數。在經過初始化之後,G不具備任何的生成能力,輸出的虛假圖片和真實圖片相去甚遠;D也不具備任何的判別能力,對於真實或虛假圖片所輸出的分數沒有太大區別

現在開始模型的優化~在每次迭代中,我們隨機選擇一批真實圖片x,並隨機生成一批z,然後將z輸入G得到一批虛假圖片x'=G(z)。D的損失函式包括兩方面:第一是x對應的分數D(x)應當比較高,例如和1儘可能接近;第二是x'對應的分數D(x')應當比較低,例如和0儘可能接近。按照損失函式減小的方向,調整D的每一個引數,便完成了D的一次優化。在經過優化之後,D對於真實或虛假圖片所輸出的分數,就更加有區分度了

至於G,其目標是讓D誤以為x'是真實圖片,因此G的損失函式可以是D(x')和1之間的差距,這個差距越小,表明在D看來x'越真實。按照損失函式減小的方向,調整G的每一個引數,便完成了G的一次優化。在經過優化之後,G合成的虛假圖片,在D的判別下,就變得更加真實了。值得一提的是,在整個優化過程中,G並沒有接觸到真實圖片x,它也不在乎真實圖片x長什麼樣,只要合成的虛假圖片x'在D看來像真的就行

重複以上步驟,隨著優化的進行,D的判別能力越來越強,G的生成能力也越來越強,兩者互相博弈、共同進步。在理想的情況下,G最終可以生成和真實圖片難以區別的虛假圖片,如此一來,為了合成一些不存在的圖片,我們只需要隨機生成很多z,輸入G即可得到對應的合成結果。

以上就是GAN的基本原理,最後再貼一下Goodfellow論文裡所使用的損失函式

preview

04 合成結果

在具體實現上,如果對深度學習(Deep Learning)中常用的層(Layer)和運算元(Operator)比較熟悉,那麼應該可以很輕鬆地想到如何實現G和D,只要使得G將隨機向量對映為圖片,而D將圖片對映為數值即可。例如,在Goodfellow的論文中,一種方法是通過全連線層(Fully-Connected Layer,也稱為Dense)來實現G和D,另一種方法則是用二維卷積層(Conv2d)實現G,以及二維逆卷積層(Deconv2d)實現D

論文展示了在MNIST、TFD、CIFAR-10三個資料集上的實驗結果,其中最右一列表示和倒數第二列中的虛假樣本,最為接近的真實樣本。雖然這些結果在今天看起來不怎麼樣,但是在論文發表的2014年,能夠實現這樣的合成效果,已經足以引起相當大的轟動了

文章出處:https://zhuanlan.zhihu.com/p/157849976