1. 程式人生 > >4 李巨集毅生成對抗網路學習———WGAN

4 李巨集毅生成對抗網路學習———WGAN

論文題目:《Wasserstein GAN》( 2017年收錄於ICML )

1 背景

生成對抗網路GAN( General Adversarial Networks )誕生以來出現的 一些問題。問題核心在GAN的目標函式JS散度

  1. 當判別器D訓練得越好,生成器G的梯度消失越嚴重。

    優化JS散度,相當於兩個分佈更加接近,讓兩個分佈重疊的時候,可以達到“以假亂 真”。但是經過分析,兩種分佈重疊的概率很小。 也就是說,無論兩種分佈相距多遠,JS 散度為一個常數,因此導致生成器的梯度近似為零,梯度消失。

  2. 梯度不穩定,多樣性不足
    在這裡插入圖片描述
    一是同時最小化生成分佈與真實分佈的KL散度,最大化JS散度,產生矛盾,造成 梯度不穩定;二是KL散度不是一個對稱的衡量,通過分析可知,KL(Pg||Pr) 和 KL(Pr||Pg) 對不同的懲罰不一樣,造成生成器G生成一些重複但是“安全”的樣本,不願生成多樣性 的樣本。

2 WGAN

提出Wasserstein 距離代替JS散度
在這裡插入圖片描述
在這裡插入圖片描述

3 程式碼實現

計算損失:

real_x = tf.placeholder(tf.float32, shape=[batch_size, mnist_dim])
random_x = tf.placeholder(tf.float32, shape=[batch_size, random_dim])
random_y = Generator(random_x)

eps = tf.random_uniform([batch_size, 1], minval=0., maxval=1.)#
inter_x = eps * real_x + (1. - eps) * random_y
grad = tf.gradients(Discriminator(inter_x), [inter_x])[0]
grad_norm = tf.sqrt(tf.reduce_sum((grad)**2,axis = 1))
grad_pen = 10 *  tf.reduce_mean(tf.nn.relu(grad_norm - 1.))

D_loss = tf.reduce_mean(Discriminator(random_y)) - tf.reduce_mean(Discriminator(real_x)) + grad_pen
G_loss = -tf.reduce_mean(Discriminator(random_y))

其中

tf.reduce_mean
reduce_mean(
    input_tensor,
    axis=None,
    keep_dims=False,
    name=None,
    reduction_indices=None
)
#!/usr/bin/python

import tensorflow as tf
import numpy as np

initial = [[1.,1.],[2.,2.]]
x = tf.Variable(initial,dtype=tf.float32)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
    sess.run(init_op)
    print(sess.run(tf.reduce_mean(x)))
    print(sess.run(tf.reduce_mean(x,0))) #Column
    print(sess.run(tf.reduce_mean(x,1))) #row

ref
https://www.alexirpan.com/2017/02/22/wasserstein-gan.html
https://vincentherrmann.github.io/blog/wasserstein/
https://zhuanlan.zhihu.com/p/25071913