1. 程式人生 > >不要慫,就是GAN (生成式對抗網路) (六):Wasserstein GAN(WGAN) TensorFlow 程式碼

不要慫,就是GAN (生成式對抗網路) (六):Wasserstein GAN(WGAN) TensorFlow 程式碼

先來梳理一下我們之前所寫的程式碼,原始的生成對抗網路,所要優化的目標函式為:

 此目標函式可以分為兩部分來看:

①固定生成器 G,優化判別器 D, 則上式可以寫成如下形式: 

 

可以轉化為最小化形式: 

我們編寫的程式碼中,d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logits, labels = tf.ones_like(D))),由於我們判別器最後一層是 sigmoid ,所以可以看出來 d_loss_real 是上式中的第一項(捨去常數概率 1/2),d_loss_fake 為上式中的第二項。

②固定判別器 D,優化生成器 G,捨去前面的常數,相當於最小化:

也相當於最小化:

我們的程式碼中,g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits = D_logits_, labels = tf.ones_like(D))),完美對應上式。

接下來開始我們的 WGAN 之旅,正如 https://zhuanlan.zhihu.com/p/25071913 所介紹的,我們要構建一個判別器 D,使得 D 的引數不超過某個固定的常數,最後一層是非線性層,並且使式子:

達到最大,那麼 L 就可以作為我們的 Wasserstein 距離,生成器的目標是最小化這個距離,去掉第一項與生成器無關的項,得到我們生成器的損失函式。我們可以把上式加個負號,作為 D 的損失函式,其中加負號後的第一項,是 d_loss_real,加負號後的第二項,是 d_loss_fake。

下面開始碼程式碼:

為了方便,我們直接在上一節我們的 none_cond_DCGAN.py 檔案中修改相應的程式碼:

在開頭的巨集定義中加入:

CLIP = [-0.01, 0.01]
CRITIC_NUM = 5

 如圖:

註釋掉原來 discriminator 的 return,重新輸入一個 return 如下:

在 train 函式裡面,修改如下地方:

在迴圈裡面,要改如下地方,這裡稍微做一下說明,idx < 25 時 D 迴圈更新 25 次才會更新 G,用來保證 D 的網路大致滿足 Wasserstein 距離,這是一個小小的 trick。

改完之後點選執行進行訓練,WGAN 收斂速度很快,大約一千多次迭代的時候,生成網路生成的影象已經很像了,最後生成的影象如下,可以看到,影象還是有些噪點和壞點的。

最後的最後,貼一張網路的 Graph:

參考文獻:

1. https://zhuanlan.zhihu.com/p/25071913