1. 程式人生 > >從GAN到WGAN到WDGRL誤差函式的深入淺出解讀

從GAN到WGAN到WDGRL誤差函式的深入淺出解讀

這篇部落格分為三部分,先介紹GANloss函式,以及它存在的問題;接下來第二節介紹WGANloss函式由來,以及實現細節;最後介紹在程式中使用最多的WGAN-GPloss函式。

  1. 傳統GAN的訓練困難原因

傳統GAN的Loss,該Loss有些不足的地方,導致了GAN的訓練十分困難,表現為:

1、模式坍塌,即生成樣本的多樣性不足;

2、不穩定,收斂不了。

原因總結如下:

用KL Divergence和JS Divergence作為兩個概率的差異的衡量,最關鍵的問題是若兩個概率的支撐集不重疊,就無法讓那個引數化的、可移動的概率分佈慢慢地移動過來,以擬合目標分佈。(即KL散度和JS散度在兩個概率分佈沒有重疊的情況下,無法反應兩者之間的差異性,因此無法進行學習優化。)

GAN的誤差函式:

判別器D的loss函式:

判別器判斷真實樣本的得分D(x)越高越好,判斷生成樣本的得分D(G(Z))越低越好

生成器G的loss函式:

生成器的目標是生成的樣本在判別器得分D (G(Z))越高越好。

 

  1. WGAN

Wasserstein Distance:(兩個概率分佈的距離衡量指標)

定義如下:

第一句話的解釋很漂亮:

W(Pr,Pg)是這兩個概率分佈的距離,它是兩個在同一空間上(維度相同)的隨機變數x,y之差的範數均值的下确界。

下确界:某個集合X 的子集 E 下确界(

英語:infimum infima,記為inf E )是小於或等於的E 所有其他元素的最大元素,其不一定在E 內。

轉化為

f(x)是函式集 中的一個函式。 表示滿足1-Lipschitz條件的函式集。(Lipschitz條件是一個比通常連續更強的光滑性條件。直覺上,Lipschitz連續函式限制了函式改變的速度,符合利Lipschitz條件的函式的斜率,必小於一個稱為Lipschitz常數的實數)。

用K-Lipschitz條件代替:

Sup指上確界,inf指下确界

式要求得到上確界,上確界的具體函式形式我們不知道,但我們可以用神經網路來逼近它,這是判別器(Discriminator)的作用,也就是Discriminator網路充當了f(x)的角色,因此(4)等價於:

其中, 是樣本函式平均值

判別器D,目標是這個距離越大越好,

因此判別器的損失函式:

生成器只能調節生成器引數,不能調節判別器引數,因此

這個距離越小越好.

參考:https://blog.csdn.net/StreamRock/article/details/81138621

其中的pytorch原始碼,清楚地解釋瞭如果在程式中得到判別器和生成器的loss,其中WGAN對權重進行了修剪:

# Clip weights of discriminator

for p in discriminator.parameters():

p.data.clamp_(-opt.clip_value, opt.clip_value)

要保證fθ(x)滿足K-Lipschitz條件,夾逼了判別器的引數。

關於WGAN的loss函式,我發現這個總結更為精闢:

WGAN中,判別器D和生成器Gloss函式分別是:

  1. WGAN-GP

參考: https://blog.csdn.net/omnispace/article/details/77790497(解釋很精彩)

大部分程式中採用WGAN-GP(Gradient penalty)。

在引入梯度懲罰項之前,先介紹採用引數夾逼的方式存在的兩個問題:

 

  1. 判別器loss希望儘可能拉大真假樣本的分數差,然而weight clipping獨立地限制每一個網路引數的取值範圍,在這種情況下我們可以想象,最優的策略就是儘可能讓所有引數走極端,要麼取最大值(如0.01)要麼取最小值(如-0.01)。

這樣帶來的結果就是,判別器會非常傾向於學習一個簡單的對映函式(想想看,幾乎所有引數都是正負0.01,都已經可以直接視為一個二值神經網路了,太簡單了)。而作為一個深層神經網路來說,這實在是對自身強大擬合能力的巨大浪費!判別器沒能充分利用自身的模型能力,經過它回傳給生成器的梯度也會跟著變差。

  1. 第二個問題,weight clipping會導致很容易一不小心就梯度消失或者梯度爆炸。原因是判別器是一個多層網路,如果我們把clipping threshold設得稍微小了一點,每經過一層網路,梯度就變小一點點,多層之後就會指數衰減;反之,如果設得稍微大了一點,每經過一層網路,梯度變大一點點,多層之後就會指數爆炸。只有設得不大不小,才能讓生成器獲得恰到好處的回傳梯度,然而在實際應用中這個平衡區域可能很狹窄,就會給調參工作帶來麻煩。相比之下,gradient penalty就可以讓梯度在後向傳播的過程中保持平穩。

既然判別器希望儘可能拉大真假樣本的分數差距,那自然是希望梯度越大越好,變化幅度越大越好,所以判別器在充分訓練之後,其梯度norm其實就會是在K附近。知道了這一點,我們可以把上面的loss改成要求梯度normK越近越好,效果是類似的:

簡單地把K定為1,再跟WGAN原來的判別器loss加權合併,就得到新的判別器loss

三個loss項均是期望的形式,在實際中通過取樣的方式獲得。前面兩個期望的取樣我們都熟悉,第一個期望是從真樣本集裡面採,第二個期望是從生成器的噪聲輸入分佈取樣後,再由生成器對映到樣本空間。可是第三個分佈要求我們在整個樣本空間 上取樣,這完全不科學!由於所謂的維度災難問題,如果要通過取樣的方式在圖片或自然語言這樣的高維樣本空間中估計期望值,所需樣本量是指數級的,實際上沒法做到。

我們其實沒必要在整個樣本空間上施加Lipschitz限制,只要重點抓住生成樣本集中區域、真實樣本集中區域以及夾在它們中間的區域就行了。具體來說,我們先隨機採一對真假樣本,還有一個0-1的隨機數:

在 和 的連線上隨機插值取樣,

·  weight clipping是對樣本空間全域性生效,但因為是間接限制判別器的梯度norm,會導致一不小心就梯度消失或者梯度爆炸;

·  gradient penalty只對真假樣本集中區域、及其中間的過渡地帶生效,但因為是直接把判別器的梯度norm限制在1附近,所以梯度可控性非常強,容易調整到合適的尺度大小。

這個採用點的獲取可以用下圖表示:

  1. 從真實資料 PdataPdata 中取樣得到一個點
  2. 從生成器生成的資料 PGPG 中取樣得到一個點
  3. 為這兩個點連線
  4. 在線上隨機取樣得到一個點作為 Ppenalty的點。

 

注意:由於我們是對每個樣本獨立地施加梯度懲罰,所以判別器的模型架構中不能使用Batch Normalization,因為它會引入同個batch中不同樣本的相互依賴關係。如果需要的話,可以選擇其他normalization方法,如Layer Normalization、Weight Normalization和Instance Normalization,這些方法就不會引入樣本之間的依賴。論文推薦的是Layer Normalization。