1. 程式人生 > 實用技巧 >深度學習《GAN模型學習》

深度學習《GAN模型學習》

前言:今天我們來一起學習下GAN神經網路,上一篇博文我先用pytorch運行了幾個網上的程式碼例子,用於生成MNIST圖片,現在我才反過來寫GAN的學習,這次反了過來,效果也是很顯而易見的,起碼有個直觀的感覺,這一篇迴歸基本知識,介紹下GAN是什麼?本文不會引用論文很複雜的公式。
論文參考自:https://arxiv.org/pdf/1406.2661.pdf
建議想學習的可以去論文學學,畢竟是大神之作,然後形成自己的認知。

一:什麼是GAN?
在以前學習機器學習演算法的時候,比如線性迴歸,邏輯迴歸,SVM等,或者是高斯混合模型的時候,我們心裡已經大致知道了大概資料的特徵的樣子,然後用某個合適的模型去根據資料特徵做出一些分類,在BP神經網路中,比如識別MNIST數字的例子,我們人眼也能看出來數字的特徵(‘1’就是一個豎線),然後通過卷積、池化、全連線等方式去根據特徵做出分類,誒呦不錯哦,這個模型能根據已知的資料的分佈特徵(對於圖片來說,畫素就是資料)去做出一些機智的判斷,大大提升效率。

但是如果我說,既然你的模型這麼聰明,模型也看過的數字圖片,它能根據資料的分佈特徵找到不同的數字類別,那模型你能跟我說一下數字‘1’到‘9’的特徵是啥麼?你能不能製造出這些類似的數字影象呢?(比如我說數字‘1’,你能畫出來它的樣子麼?)

這就是GAN所做的工作,根據現實的資料,去學習出資料的分佈特徵,從而製造出類似的資料,這個模型就叫做GAN。

二:慢慢深入
前段時間有個熱播的電視劇《古董局中局》,講的就是古董造假行業和鑑別真偽世家之間的無休無止糾纏錯亂的故事。
舉個例子,小D是一名鑑別師,小G是一名造假師,他倆自出生開始就是宿命相對。

小D出生古董行世家,從小看了很多真實的古玩名跡,金石玉器,擁有了一定程度的鑑別古董真偽的能力。小G也能簡單的造出一些常人都能識破的假物件,當然小D更能識破,讓後告訴小G你拿給我的都是假的。

小G很刻苦,繼續訓練,虛心接受小D的評判結果和反饋更新,不斷提升造假能力,而小D也沒閒著,也在不斷提升鑑別能力,相互促進相互提高,最後小G的造假能力大大提升,足夠以假亂真,直到小D再也分不清楚哪個是真哪個是假,對每一個物件的判別真實度都是50%概率了,這個博弈過程就結束了。

GAN的模型也是有一個鑑別模型(Discriminator)和一個造假模型(Discriminator)組成,如下圖:
在這裡插入圖片描述

簡單描述下就是:D是一個判別器,X是真實的資料,D(X)表示X是真實資料的概率(同樣的1- D(X)則表示判定X是假資料的概率),範圍是0~1。另一方面,Z是噪聲隨機數,維度任意,G是構造器,經過G後的G(X)是和X相同維度的資料,是假資料Xg = G(X),然後同樣交給D模型去判別真偽。真和假是兩個分類,我們是不是想到了邏輯迴歸?稍等下我們會涉及到。我們來分析下各個模型的內心活動啊。

1)對於D來說,D想要把真實資料X儘可能都識別成真實資料,也就是希望D(X)儘量越大越好,最好趨近於1。相反對於構造的資料G(X),希望鑑別出來都是假,也就是希望D(G(X))儘量越小越好,最好趨近於0,希望越小的概率認為真實的就等希望同於越大的概率認為是虛假的,所以換句話說就是對於構造的資料G(X),希望鑑別出來都是假,也就是希望1-D(G(X))儘量越大越好,最好趨近於1。

2)對於G來說,真實資料G管不著,我只能希望構造生成的資料能騙過D認為是真實的,使得D(X)逼近於1。這就是二者博弈的過程,最終達到一個均衡點,使得D(X)傻傻分不清了,什麼都是50%的概率進行判別。

三:詳細計算
下面給出損失函式,截圖自論文:
在這裡插入圖片描述

這裡簡單說一下,這裡是倆個損失函式的合併寫法,現在我拆開分析,就更加清晰了。
1)對於D來說希望把真實資料儘量大概率預測為真,把構造資料儘量大概率預測假,上述也說了,1-D(x)就是預測為假資料的概率。
在這裡插入圖片描述

假設真實資料的標籤是1,虛假資料的標籤是0。上述式子可以合併,給出損失函式,按照邏輯迴歸的代價函式:在這裡插入圖片描述

對D來說如果想最小化代價函式L,那麼就是最大化V(D,G)。

2)對G來說希望騙過D,使得D認為G構造的資料儘可能大概率認為真實的,即D(G(X))儘可能等於1,也可以說是希望D判別出是假資料的概率越小越好,1-D(x)儘可能等於0,只需要這點就夠了,真實資料不歸G管,如下。
在這裡插入圖片描述

G希望P越大越好,越接近1越好,即F越小越好,越接近0越好,二者是熵等價的。由於我們出於習慣一般都是求最小化,而且為了和D的式子保持統一風格,因為選F計算的概率,又因為G只管Z資料,所以G模型的樣本都是來自假的資料。
在這裡插入圖片描述

二者式子歸併在一起就是論文中給出的式子了,我感覺它就是個抽象的定義。希望上述的公式分析,對下面展示的梯度下降過程有一定的幫助,具體的論文中給的梯度下降過程是如下所示:
在這裡插入圖片描述

上述注意幾點啊(前面已經分析了代價和函式,請一起參考)。
1:每一輪迭代都是先訓練D,且D會訓練K輪,K=1是最低的要求。然後再訓練G。

2:對D來說,因為論文中它給的式子是V(D,G),對V(D,G)求偏導數,所以引數更新的時候需要加上這個偏導,而我們習慣上是對Loss函式求偏導,剛好Loss函式就是V(D,G)的相反數,使用正常的方式對Loss求偏導的話,結果也是一樣的。

3:對G來說,V(D,G)和Loss的函式是一樣的。都是求最小化,因為求完偏導數。需要做減法。

四:後續學習介紹:
在影象處理領域,最好最多的模型是CNN,怎麼將二者結合起來呢?之前我也在Pytorch《GAN模型生成MNIST數字》使用了CNN的構造生成模型,有興趣可以看看。

DCGAN(Deep Convolutional Generative Adversarial Networks)模型就是一個在這方面非常好的嘗試,也是比較流行使用的模型方式,下一篇文章我將好好學習這個,並且給出實際的影象生成的例子,請見下一篇Pytorch《DCGAN模型》。

推薦個github地址,這裡有很多GAN相關的論文:
https://github.com/zhangqianhui/AdversarialNetsPapers