1. 程式人生 > 實用技巧 >生成式對抗網路(Generative Adversarial Networks, GANs)

生成式對抗網路(Generative Adversarial Networks, GANs)

1 原始的 GANs

1.1 GANs 的結構

GANs 的結果圖如下所示:

在這裡插入圖片描述

生成式對抗網路 GANs 最重要的兩個部分為:

  • 生成器(Generator) :用於生成“假”樣本。生成器從先驗分佈中採得隨機訊號,經過神經網路的變換,得到模擬樣本。
  • 判別器(Discriminator) :用於判斷輸入的樣本是真實的還是合成的。判別器既接收來自實際資料集的真實樣本,也接收來自生成器的模擬樣本,判別器需要判斷輸入的樣本是真實資料還是生成器的模擬(假)資料。

從上面可以看出,生成器和判別器是對抗的關係,生成器要儘可能生成出讓判別器失敗的樣本,而判別器要儘可能識別出生成器的假樣本。GANs 就是通過這種對抗的關係,讓生成器和判別器不斷提升。理想狀態下,生成器和判別器最終能達到一種平衡,兩者都趨於完美,都沒有更進一步的空間。

1.2 GANs 的訓練過程

GANs 採用生成器和判別器交替優化的方式:

(1)固定生成器 G G G,訓練判別器 D D D

固定生成器 G G G,然後利用生成器隨機模擬產生樣本 G ( z ) G(z) G(z) 作為負樣本( z z z 是一個隨機向量),並從真實資料集中取樣獲得正樣本 X X X,將這些正負樣本輸入到判別器 D D D 中,根據判別器的輸出(即 D ( X ) D(X) D(X) D ( G ( z ) ) D(G(z)) D(G(z)) )和樣本標籤來計算誤差,最後利誤差反向傳播演算法來更新判別器的引數,如下圖所示

在這裡插入圖片描述

(2)固定判別器 D D

D,訓練生成器 G G G

固定判別器 D D D,然後利用當前生成器 G G G 隨機模擬產生樣本 G ( z ) G(z) G(z),並輸入到判別器 D D D 中;根據判別器的輸出 D ( G ( z ) ) D(G(z)) D(G(z)) 和樣本標籤來計算誤差,最後利用誤差反向傳播演算法來更新生成器 G G G 的引數,如下圖所示:

在這裡插入圖片描述

1.3 GANs 的訓練模型

先給出 GANs 的公式:

min ⁡ G max ⁡ D V ( D , G ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E z ∼ p z ( z ) [ log ⁡ D ( G ( z ) ) ] (1) \min_G \max_D V(D,G)=E_{x\sim p_{data}(x)}[\log D(x)] + E_{z\sim p_{z}(z)}[\log D(G(z))] \tag{1}

GminDmaxV(D,G)=Expdata(x)[logD(x)]+Ezpz(z)[logD(G(z))](1)

訓練模型中需要用到的符號有:

  • G G G:生成器模型,通常為一個多層感知機結構的可微函式
  • D D D:判別器模型
  • x x x:判別器的輸入,包括真實資料樣本和生成器的輸出
  • z z z:生成器輸入的噪聲變數,則生成器的輸出為 x = G ( z ) x=G(z) x=G(z)
  • p d a t a ( x ) ≐ p ( x ∣ d a t a ) p_{data}(x) \doteq p(x|data) pdata(x)p(xdata):表示從實際資料集得到樣本 x x x 的概率
  • p z ( z ) p_{z}(z) pz(z):生成器輸入的噪聲變數 z z z 的先驗分佈
  • p g ( x ) ≐ p ( x ∣ g ) p_{g}(x) \doteq p(x|g) pg(x)p(xg):生成器輸出的樣本 x x x 的概率
  • p s r c ( d a t a ) p_{src}(data) psrc(data) p s r c ( g ) p_{src}(g) psrc(g):判別器模型輸入樣本中來自真實資料和來自生成器的概率,一般採用一半真實資料、一半假資料的方式,即: p s r c ( d a t a ) = p s r c ( g ) = 1 2 p_{src}(data)=p_{src}(g)=\frac{1}{2} psrc(data)=psrc(g)=21
  • G ( z ; θ g ) G(z;\theta_g) G(z;θg) θ g \theta_g θg 為生成器的多層感知機的引數, G ( z ; θ g ) G(z;\theta_g) G(z;θg) 代表生成器模型的輸出空間
  • D ( x ; θ d ) D(x;\theta_d) D(x;θd) θ d \theta_d θd 為判別器的多層感知機的引數, D ( x ; θ d ) D(x;\theta_d) D(x;θd) 為判別器的輸出,是一個標量值
  • D ( x ) D(x) D(x):判別器預測輸入樣本 x x x 來自於真實資料集的概率
  • ( G ∗ , D ∗ ) (G^*,D^*) (G,D):求得的解,即達到最終納什均衡點時的生成器和判別器

1.3.1 生成器 G G G 固定,尋求當下最優的判別器 D G ∗ D_G^* DG

判別器 D D D 實質上解決的是一個二分類問題,其損失函式可以用 負對數似然(Negative Log-Likelihood,NLL),也稱 絕對交叉熵損失(Categorical Cross-Entropy Loss) 來表示:

L ( D ) = − ∫ p ( x ) [ p ( d a t a ∣ x ) log ⁡ D ( x ) + p ( g ∣ x ) log ⁡ ( 1 − D ( x ) ) ] d x (2) L(D)=-\int p(x)[p(data|x) \log D(x) + p(g|x) \log (1-D(x))]dx \tag{2} L(D)=p(x)[p(datax)logD(x)+p(gx)log(1D(x))]dx(2)

其中:

  • p ( d a t a ∣ x ) p(data|x) p(datax):樣本 x x x 屬於真實資料集的概率
  • p ( g ∣ x ) p(g|x) p(gx):樣本 x x x 屬於生成器的概率

我們可以推出:

p ( x ) p ( d a t a ∣ x ) = p s r c ( d a t a ) p ( x ∣ d a t a ) = p s r c ( d a t a ) p d a t a ( x ) = 1 2 p d a t a ( x ) p(x)p(data|x)=p_{src}(data)p(x|data)=p_{src}(data)p_{data}(x)=\frac{1}{2}p_{data}(x) p(x)p(datax)=psrc(data)p(xdata)=psrc(data)pdata(x)=21pdata(x)

p ( x ) p ( g ∣ x ) = p s r c ( g ) p ( x ∣ g ) = p s r c ( g ) p g ( x ) = 1 2 p g ( x ) p(x)p(g|x)=p_{src}(g)p(x|g)=p_{src}(g)p_{g}(x)=\frac{1}{2}p_{g}(x) p(x)p(gx)=psrc(g)p(xg)=psrc(g)pg(x)=21pg(x)

代入公式 (2)則有:

L ( D ) = − ∫ p ( x ) [ p ( d a t a ∣ x ) log ⁡ D ( x ) + p ( g ∣ x ) log ⁡ ( 1 − D ( x ) ) ] d x = − ∫ [ 1 2 p d a t a ( x ) log ⁡ D ( x ) + 1 2 p g ( x ) log ⁡ ( 1 − D ( x ) ) ] d x = − 1 2 ( E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E x ∼ p g ( x ) [ log ⁡ ( 1 − D ( x ) ) ] ) (3) \begin{aligned} L(D) &= -\int p(x)[p(data|x) \log D(x) + p(g|x) \log (1-D(x))]dx \\ &= -\int [\frac{1}{2}p_{data}(x) \log D(x) + \frac{1}{2}p_{g}(x) \log (1-D(x))]dx \\ &= -\frac{1}{2}(E_{x\sim p_{data}(x)}[\log D(x)] + E_{x\sim p_{g}(x)}[\log (1-D(x))]) \tag{3} \end{aligned} L(D)=p(x)[p(datax)logD(x)+p(gx)log(1D(x))]dx=[21pdata(x)logD(x)+21pg(x)log(1D(x))]dx=21(Expdata(x)[logD(x)]+Expg(x)[log(1D(x))])(3)

因此,尋求當下最優的判別器 D G ∗ D_G^* DG 就是最大化以下值函式:

V ( D ) = E x ∼ p d a t a ( x ) [ log ⁡ D ( x ) ] + E x ∼ p g ( x ) [ log ⁡ ( 1 − D ( x ) ) ] (4) V(D)=E_{x\sim p_{data}(x)}[\log D(x)] + E_{x\sim p_{g}(x)}[\log (1-D(x))] \tag{4} V(D)=Expdata(x)[logD(x)]+Expg(x)[log(1D(x))](4)

對於單個樣本 x x x,則最大化值函式:

max ⁡ D p d a t a ( x ) log ⁡ D ( x ) + p g ( x ) [ log ⁡ ( 1 − D ( x ) ) ] (5) \max_D p_{data}(x)\log D(x) + p_{g}(x)[\log (1-D(x))] \tag{5} Dmaxpdata(x)logD(x)+pg(x)[log(1D(x))](5)

p d a t a ( x ) = a p_{data}(x)=a pdata(x)=a p g ( x ) = b p_{g}(x)=b pg(x)=b D ( x ) = D D(x)=D D(x)=D,則式(5)可以寫作:

f ( D ) = a D + b D f(D)= aD + bD f(D)=aD+bD

令其對 D D D 的導數為零有

d f ( D ) d D = a 1 D − b 1 1 − D = 0 \frac{df(D)}{dD}= a \frac{1}{D} - b \frac{1}{1-D}=0 dDdf(D)=aD1b1D1=0

從而有:

D ∗ = a a + b D^*=\frac{a}{a+b} D=a+ba

即:

D ∗ ( x ) = p d a t a ( x ) p d a t a ( x ) + p g ( x ) (6) D^*(x)=\frac{p_{data}(x)}{p_{data}(x)+p_{g}(x)} \tag{6} D(x)=pdata(x)+pg(x)pdata(x)(6)

在公式(6)外面套上對 x x x 的積分,解由單點變成函式解:

D G ∗ ( x ) = p d a t a p d a t a + p g (7) D^*_G(x)=\frac{p_{data}}{p_{data}+p_{g}} \tag{7} DG(x)=pdata+pgpdata(7)

將公式(7)代入公式(4)中,有:

V ( D G ∗ ( x ) ) = E x ∼ p d a t a ( x ) [ log ⁡ p d a t a p d a t a + p g ] + E x ∼ p g ( x ) [ log ⁡ ( 1 − p d a t a p d a t a + p g ) ] = ∫ x p d a t a ( x ) log ⁡ p d a t a p d a t a + p g d x + ∫ x p g ( x ) log ⁡ p g p d a t a + p g d x = ∫ x p d a t a ( x ) log ⁡ [ 1 2 × p d a t a ( p d a t a + p g ) / 2 ] d x + ∫ x p g ( x ) log ⁡ [ 1 2 × p g ( p d a t a + p g ) / 2 ] d x = − 2 log ⁡ 2 + ∫ x p d a t a ( x ) log ⁡ [ p d a t a ( p d a t a + p g ) / 2 ] d x + ∫ x p g ( x ) log ⁡ [ p g ( p d a t a + p g ) / 2 ] d x \begin{aligned} V(D^*_G(x)) &= E_{x\sim p_{data}(x)}[\log \frac{p_{data}}{p_{data}+p_{g}} ] + E_{x\sim p_{g}(x)}[\log (1-\frac{p_{data}}{p_{data}+p_{g}} )] \\ &= \int_x p_{data}(x) \log \frac{p_{data}}{p_{data}+p_{g}} dx + \int_x p_{g}(x) \log \frac{p_{g}}{p_{data}+p_{g}} dx \\ &= \int_x p_{data}(x) \log [\frac{1}{2} \times \frac{p_{data}}{(p_{data}+p_{g})/2}] dx + \int_x p_{g}(x) \log [\frac{1}{2} \times \frac{p_{g}}{(p_{data}+p_{g})/2}] dx \\ &= -2\log 2 + \int_x p_{data}(x) \log [\frac{p_{data}}{(p_{data}+p_{g})/2}] dx + \int_x p_{g}(x) \log [\frac{p_{g}}{(p_{data}+p_{g})/2}] dx \end{aligned} V(DG(x))=Expdata(x)[logpdata+pgpdata]+Expg(x)[log(1pdata+pgpdata)]=xpdata(x)logpdata+pgpdatadx+xpg(x)logpdata+pgpgdx=xpdata(x)log[21×(pdata+pg)/2pdata]dx+xpg(x)log[21×(pdata+pg)/2pg]dx=2log2+xpdata(x)log[(pdata+pg)/2pdata]dx+xpg(x)log[(pdata+pg)/2pg]dx

KL 散度(Kullback–Leibler Divergence)
又稱相對熵(Relative Entropy),兩個分佈 P P P Q Q Q 的 KL 散度記為 K L ( P ∥ Q ) KL(P\| Q) KL(PQ),計算公式為:
K L ( P ∥ Q ) = E x ∼ P ( x ) [ log ⁡ P ( x ) Q ( x ) ] = ∑ i = 1 n [ P ( x i ) log ⁡ P ( x i ) Q ( x i ) ] KL(P\| Q)=E_{x\sim P(x)}[\log\frac{P(x)}{Q(x)}]=\sum_{i=1}^n [P(x_i) \log\frac{P(x_i)}{Q(x_i)}] KL(PQ)=ExP(x)[logQ(x)P(x)]=i=1n[P(xi)logQ(xi)P(xi)]
JS 散度(Jensen–Shannon Divergence)
兩個分佈 P P P Q Q Q 的 JS 散度記為 J S ( P ∥ Q ) JS(P\| Q) JS(PQ),計算公式為:
J S ( P ∥ Q ) = 1 2 K L ( P ∥ P + Q 2 ) + 1 2 K L ( Q ∥ P + Q 2 ) JS(P\| Q)=\frac{1}{2}KL(P\| \frac{P+Q}{2}) + \frac{1}{2}KL(Q \| \frac{P+Q}{2}) JS(PQ)=21KL(P2P+Q)+21KL(Q2P+Q)

V ( D G ∗ ( x ) ) = − 2 log ⁡ 2 + ∫ x p d a t a ( x ) log ⁡ [ p d a t a ( p d a t a + p g ) / 2 ] d x + ∫ x p g ( x ) log ⁡ [ p g ( p d a t a + p g ) / 2 ] d x = − 2 log ⁡ 2 + K L ( p d a t a ∥ p d a t a + p g 2 ) + K L ( p g ∥ p d a t a + p g 2 ) = − 2 log ⁡ 2 + 2 J S ( p d a t a ∥ p g ) ∈ [ − 2 log ⁡ 2 , 0 ] \begin{aligned} V(D^*_G(x)) &= -2\log 2 + \int_x p_{data}(x) \log [\frac{p_{data}}{(p_{data}+p_{g})/2}] dx + \int_x p_{g}(x) \log [\frac{p_{g}}{(p_{data}+p_{g})/2}] dx \\ &= -2\log 2 + KL(p_{data} \| \frac{p_{data}+p_{g}}{2}) + KL(p_{g} \| \frac{p_{data}+p_{g}}{2}) \\ &= -2\log 2 + 2 JS(p_{data} \| p_{g}) \\ & \in [-2\log 2, 0] \end{aligned} V(DG(x))=2log2+xpdata(x)log[(pdata+pg)/2pdata]dx+xpg(x)log[(pdata+pg)/2pg]dx=2log2+KL(pdata2pdata+pg)+KL(pg2pdata+pg)=2log2+2JS(pdatapg)[2log2,0]

固定判別器為 D G ∗ D_G^* DG 時,求生成器 G G G 的值函式可以寫作:

min ⁡ G V ( G , D G ∗ ( x ) ) = min ⁡ G { − 2 log ⁡ 2 + 2 J S ( p d a t a ∥ p g ) } \min_G V(G, D^*_G(x))=\min_G\{ -2\log 2 + 2 JS(p_{data} \| p_{g}) \} GminV(G,DG(x))=Gmin{2log2+2JS(pdatapg)}

顯然,當 p d a t a = p g p_{data} = p_{g} pdata=pg 時, J S ( p d a t a ∥ p g ) = 0 JS(p_{data} \| p_{g})=0 JS(pdatapg)=0,最優解 G ∗ ( z ) = x ∼ p d a t a ( x ) G^*(z)=x \sim p_{data}(x) G(z)=xpdata(x) D ∗ ( x ) ≡ 1 2 D^*(x) \equiv \frac{1}{2} D(x)21,值函式 V ( G ∗ , D ∗ ) = − 2 l o g 2 V(G^*,D^*)=-2log2 V(G,D)=2log2

1.3.2 判別器 D D D 固定,尋求當下最優的判別器 G ∗ G^* G

G ′ G' G 為上一步的生成器, D D D 為在 G ′ G' G 下求得的最優判別器 D G ′ ∗ ( x ) D^*_{G'}(x) DG(x),那麼,求解最優 G ∗ G^* G 的過程為:

a r g min ⁡ G V ( G , D G ′ ∗ ) = a r g min ⁡ G [ K L ( p g ∥ p d a t a + p g ′ 2 ) − K L ( p g ∥ p g ′ ) ] arg \min_G V(G,D^*_{G'})=arg \min_G [KL(p_g \| \frac{p_{data}+p_{g'}}{2})-KL(p_g \| p_{g'})] argGminV(G,DG)=argGmin[KL(pg2pdata+pg)KL(pgpg)]

由此可以得出兩個結論:

  • 優化 G G G 的過程是讓 G G G 遠離前一步的 G ′ G' G,同時接近分佈 p d a t a + p g ′ 2 \frac{p_{data}+p_{g'}}{2} 2pdata+pg
  • 達到均衡點時 p g ′ = p d a t a p_{g'}=p_{data} pg=pdata,有 a r g min ⁡ G V ( G , D G ′ ∗ ) = a r g min ⁡ G 0 arg \min_G V(G,D^*_{G'})=arg \min_G 0 argminGV(G,DG)=argminG0,如果用這時的判別器去訓練一個新的生成器 G n e w G_{new} Gnew,理論上可能訓練不出來。

1.4 GANs 總結

(1)GANs 本質上式在最小化生成分佈與真實資料分佈的 JS 距離,當演算法收斂時生成器刻畫的分佈就是真實資料的分佈。

(2)發明 GANs 的初衷是為了更好地解決概率生成模型的估計問題

傳統概率生成模型方法(如:馬爾可夫隨機場、 貝葉斯網路)會涉及大量難以完成的概率推斷計算,而 GANs 可以避開這類計算。

如果隨機變數 Z Z Z X X X 之間滿足某種對映關係 X = f ( Z ) X=f(Z) X=f(Z),那麼它們的概率分佈 p X ( X ) p_X(X) pX(X) p Z ( Z ) p_Z(Z) pZ(Z) 也存在某種對映關係。當 Z , X ∈ R Z,X\in R Z,XR 都是一維隨機變數時, p X = d f ( Z ) d X p Z p_X=\frac{df(Z)}{dX}p_Z pX=dXdf(Z)pZ;當 Z , X Z,X Z,X 都是高維隨機變數時,導數變成雅克比矩陣 p X = J p Z p_X=Jp_Z pX=JpZ。 因此,已知 Z Z Z 的分佈,我們對隨機變數間的轉換函式 f f f 直接建模,就唯一確定了 X X X 的分佈。

這樣,不僅避開了大量複雜的概率計算,而且給 f f f 更大的發揮空間,我們可以用神經網路來訓練 f f f

1.5 GANs 存在的問題

在實際訓練中,早期階段生成器 G G G 很差,生成的模擬樣本很容易被判別器 D D D 識別,使得 D D D 回傳給 G G G 的梯度極小,達不到訓練的目的,這個現象稱為 優化飽和

原因分析

這裡將 D D D 的 Sigmoid 輸出層的前一層記為 o o o,那麼 D ( x ) D(x) D(x) 就可以表示成 D ( x ) = S i g m o i d ( o ( x ) ) D(x)=Sigmoid(o(x)) D(x)=Sigmoid(o(x)),此時有:

∇ D ( x ) = ∇ S i g m o i d ( o ( x ) ) = D ( x ) ( 1 − D ( x ) ) ∇ o ( x ) \nabla D(x)= \nabla Sigmoid(o(x)) = D(x)(1-D(x))\nabla o(x) D(x)=Sigmoid(o(x))=D(x)(1D(x))o(x)

因此訓練 G G G 的梯度為:

∇ log ⁡ ( 1 − D ( G ( z ; θ g ) ) ) = − D ( G ( z ; θ g ) ) ∇ o ( G ( z ; θ g ) ) \nabla \log(1-D(G(z;\theta_g))) = -D(G(z;\theta_g))\nabla o(G(z;\theta_g)) log(1D(G(z;θg)))=D(G(z;θg))o(G(z;θg))

D D D 能很好的分類樣本時,意味著認錯假樣本的概率幾乎為零,即 D ( G ( z ; θ g ) ) → 0 D(G(z;\theta_g)) \rightarrow 0 D(G(z;θg))0,假定 ∣ o ( G ( z ; θ g ) ) ∣ < C |o(G(z;\theta_g))|<C o(G(z;θg))<C C C C 為一個常數),則可推出:

lim ⁡ D ( G ( z ; θ g ) ) → 0 ∇ log ⁡ ( 1 − D ( G ( z ; θ g ) ) ) = − lim ⁡ D ( G ( z ; θ g ) ) → 0 D ( G ( z ; θ g ) ) ∇ o ( G ( z ; θ g ) ) = 0 \lim_{D(G(z;\theta_g)) \rightarrow 0} \nabla \log(1-D(G(z;\theta_g)))=-\lim_{D(G(z;\theta_g)) \rightarrow 0} D(G(z;\theta_g))\nabla o(G(z;\theta_g))=0 D(G(z;θg))0limlog(1D(G(z;θg)))=D(G(z;θg))0limD(G(z;θg))o(G(z;θg))=0

G G G 獲得的梯度基本為零,因此 D D D 強大後對 G G G 的幫助反而很小。

解決方法

解決方案是將 log ⁡ ( 1 − D ( G ( z ; θ g ) ) ) \log(1-D(G(z;\theta_g))) log(1D(G(z;θg))) 變為 log ⁡ ( D ( G ( z ; θ g ) ) \log(D(G(z;\theta_g)) log(D(G(z;θg)),形式上有一個負號的差別,故讓後者最大等效於讓前者最小,二者在最優解相同。

更改後的目標函式的梯度為:

log ⁡ ( D ( G ( z ; θ g ) ) ) = ( 1 − D ( G ( z ; θ g ) ) ) ∇ o ( G ( z ; θ g ) ) lim ⁡ D ( G ( z ; θ g ) ) → 0 ∇ log ⁡ ( D ( G ( z ; θ g ) ) ) = ∇ o ( G ( z ; θ g ) ) \begin{aligned} \log(D(G(z;\theta_g))) &= (1-D(G(z;\theta_g))) \nabla o(G(z;\theta_g)) \\ \lim_{D(G(z;\theta_g)) \rightarrow 0} \nabla \log(D(G(z;\theta_g))) &= \nabla o(G(z;\theta_g)) \end{aligned} log(D(G(z;θg)))D(G(z;θg))0limlog(D(G(z;θg)))=(1D(G(z;θg)))o(G(z;θg))=o(G(z;θg))

因此,更改後即使 D ( G ( z ; θ g ) ) → 0 D(G(z;\theta_g)) \rightarrow 0 D(G(z;θg))0 ∇ log ⁡ ( D ( G ( z ; θ g ) ) ) \nabla \log(D(G(z;\theta_g))) log(D(G(z;θg))) 也不會消失,仍能給生成器提供有效的梯度。

(GAN 的變種演算法以後再繼續補充)

2 WGAN

原始 GAN 的判別器是最小化生成分佈與真實資料分佈的 JS 距離,WGAN演算法的改進在於它使用的是 Wasserstein 距離,也稱 推土機距離(Earth Mover Distance)

W ( P , Q ) = inf ⁡ γ ∼ ∏ ( P , Q ) E ( x , y ) ∼ γ [ ∥ x − y ∥ ] W(P,Q)=\inf_{\gamma \sim \prod(P,Q)} E_{(x,y) \sim \gamma}[\|x-y\|] W(P,Q)=γ(P,Q)infE(x,y)γ[xy]

3 DCGAN

使用卷積神經網路的GAN。

4 ALI(Adversarially Learned Inference)

將生成網路和推斷網路一起放到 GANs 的框架下,進而聯合訓練生成模型和推斷模型。

5 IRGAN(Information Retrieval GAN)

利用 GANs 框架生成離散樣本資料

6 SeqGAN(Sequence GAN)

利用 GANs 框架生成文字序列