Fun with Deep Learning | 16: The Astonishing WGAN
Introduction
Building on DCGAN, this article introduces how WGAN works and how to implement it, and then applies it to the LFW and CelebA face datasets
Problems
GANs have long faced the following problems and challenges
- Training is difficult: the network architecture must be carefully designed, and the training progress of G and D must be carefully balanced
- The loss values of G and D do not indicate how training is going; there is no meaningful metric that correlates with the quality of the generated images
- Mode collapse: the generated images may look realistic but lack diversity
Theory
Compared with the original GAN, WGAN makes only three simple changes (the resulting objective is summarized after this list)
- Remove the sigmoid from the last layer of D
- Do not take the log in the G and D losses (i.e., drop sigmoid_cross_entropy_with_logits)
- After every update of D's parameters, clip their absolute values to a fixed constant c, i.e., weight clipping (the original WGAN); or, in the follow-up work WGAN-GP, use a gradient penalty instead
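Taken together (this is just a compact restatement of the standard form from the WGAN paper, not something specific to the code below), these changes replace the original GAN objective with an approximation of the Wasserstein-distance objective

$$\min_G \max_{D,\ \|D\|_L \le 1} \ \mathbb{E}_{x \sim P_r}[D(x)] - \mathbb{E}_{x \sim P_g}[D(x)]$$

where D is now a real-valued critic constrained to be 1-Lipschitz (hence the weight clipping or gradient penalty), and the value of this objective, unlike the original GAN loss, meaningfully tracks how close the generated distribution $P_g$ is to the real distribution $P_r$.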
The loss function of G was originally

$$\mathbb{E}_{x \sim P_g}[\log(1 - D(x))]$$
The consequence is that if D is trained too well, G cannot obtain a useful gradient
But if D is not trained well enough, G also cannot obtain a useful gradient
It is like a policeman and a thief: if the policeman is too strong he simply eliminates the thief outright, but if he is too weak he never forces the thief to improve
This loss function therefore makes GAN training particularly unstable and requires carefully balancing how well G and D are trained
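For reference, here is a compact version of the argument from the standard GAN/WGAN analysis (not part of this article's code): with D fixed at its optimum

$$D^*(x) = \frac{P_r(x)}{P_r(x) + P_g(x)}$$

minimizing the original G loss is equivalent to minimizing $2\,JS(P_r \,\|\, P_g) - 2\log 2$. For high-dimensional image data, $P_r$ and $P_g$ typically have negligible overlap, so the JS divergence is stuck at the constant $\log 2$ and G's gradient vanishes, which is exactly the "D trained too well" failure mode above.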
The GAN authors therefore proposed an alternative version of the G loss, the so-called $-\log D$ trick

$$\mathbb{E}_{x \sim P_g}[-\log D(x)]$$

G needs to minimize this loss, which (again with D fixed at its optimum) is equivalent to minimizing

$$KL(P_g \,\|\, P_r) - 2\,JS(P_r \,\|\, P_g)$$
where the first term is the KL divergence (Kullback–Leibler divergence)
and the second term is the JS divergence (Jensen–Shannon divergence)
Both measure the distance between two distributions: the smaller the value, the more similar the distributions
This loss therefore tries to decrease the KL divergence while simultaneously increasing the JS divergence, pulling the two distributions together with one hand and pushing them apart with the other, which again makes training unstable
In addition, because the KL divergence is asymmetric, the two cases below are penalized very differently (see the limiting behaviour sketched after this list)
- G generates unrealistic images (lack of accuracy): the penalty is high
- G generates images very similar to real ones but misses other real modes (lack of diversity): the penalty is low
As a result, G prefers to produce safe, similar-looking images rather than risk generating novel ones it is less sure about, which is exactly the mode collapse problem
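A short sketch of that asymmetry (standard KL analysis, stated here only for completeness): writing

$$KL(P_g \,\|\, P_r) = \int P_g(x) \log \frac{P_g(x)}{P_r(x)} \, dx$$

when G produces an unrealistic sample ($P_g(x) > 0$ where $P_r(x) \to 0$) the integrand blows up towards $+\infty$, a huge penalty; when G simply fails to cover a real mode ($P_g(x) \to 0$ where $P_r(x) > 0$) the integrand goes to $0$, a negligible penalty. The safest strategy for G is therefore to stay on the few modes it already reproduces well.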
The three changes made by WGAN address GAN's training difficulty and instability as well as mode collapse; moreover, the smaller G's loss, the higher the quality of the generated images
The WGAN training procedure follows the algorithm given in the paper; the gradient penalty is what forces D to satisfy the 1-Lipschitz continuity condition, and the detailed theory and derivations can be found in the papers
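To make the objective concrete, here is a summary of the critic and generator losses with gradient penalty, as described in the "Improved Training of Wasserstein GANs" paper (this is what the TensorFlow code further below computes)

$$L_D = \mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})] - \mathbb{E}_{x \sim P_r}[D(x)] + \lambda\, \mathbb{E}_{\hat{x}}\big[(\|\nabla_{\hat{x}} D(\hat{x})\|_2 - 1)^2\big], \qquad L_G = -\mathbb{E}_{\tilde{x} \sim P_g}[D(\tilde{x})]$$

where $\hat{x} = \epsilon x + (1 - \epsilon)\tilde{x}$ with $\epsilon \sim U[0, 1]$ is a random interpolation between a real and a generated sample, and $\lambda$ corresponds to the LAMBDA = 10 constant in the code.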
According to the experiments reported in the paper, WGAN needs more training time but converges more stably
More importantly, WGAN provides a more stable GAN framework: the G in DCGAN collapses if Batch Normalization is removed, whereas WGAN has no such restriction
If WGAN is implemented with the same deep convolutional architecture, its results are roughly on par with DCGAN. But within the WGAN framework, G and D can use deeper and more complex networks, for example ResNet (https://arxiv.org/abs/1512.03385), to achieve better generation quality
Data
We again use the two face datasets from earlier articles in this series
- LFW: http://vis-www.cs.umass.edu/lfw/, Labeled Faces in the Wild, more than 13,000 face images collected from the web
- CelebA: http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html, the CelebFaces Attributes Dataset, with more than 200,000 images of 10,177 identities; each image is annotated with the locations of 5 facial landmarks and 40 binary (0/1) attributes such as glasses, hat, beard, and so on
Implementation
Load the libraries
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import os
import matplotlib.pyplot as plt
%matplotlib inline
from imageio import imread, imsave, mimsave
import cv2
import glob
from tqdm import tqdm
Choose the dataset
dataset = 'lfw_new_imgs' # LFW
# dataset = 'celeba' # CelebA
images = glob.glob(os.path.join(dataset, '*.*'))
print(len(images))
Define some constants, the network inputs, and helper functions
batch_size = 100
z_dim = 100
WIDTH = 64
HEIGHT = 64
LAMBDA = 10
DIS_ITERS = 3 # 5
OUTPUT_DIR = 'samples_' + dataset
if not os.path.exists(OUTPUT_DIR):
    os.mkdir(OUTPUT_DIR)
X = tf.placeholder(dtype=tf.float32, shape=[batch_size, HEIGHT, WIDTH, 3], name='X')
noise = tf.placeholder(dtype=tf.float32, shape=[batch_size, z_dim], name='noise')
is_training = tf.placeholder(dtype=tf.bool, name='is_training')
def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)
The discriminator (critic). Note that Batch Normalization must be removed here: it creates dependence between the samples in a batch, which interferes with the per-sample gradient penalty computation
def discriminator(image, reuse=None, is_training=is_training):
    momentum = 0.9  # unused here: batch norm is deliberately left out of the critic
    with tf.variable_scope('discriminator', reuse=reuse):
        h0 = lrelu(tf.layers.conv2d(image, kernel_size=5, filters=64, strides=2, padding='same'))
        h1 = lrelu(tf.layers.conv2d(h0, kernel_size=5, filters=128, strides=2, padding='same'))
        h2 = lrelu(tf.layers.conv2d(h1, kernel_size=5, filters=256, strides=2, padding='same'))
        h3 = lrelu(tf.layers.conv2d(h2, kernel_size=5, filters=512, strides=2, padding='same'))
        h4 = tf.contrib.layers.flatten(h3)
        h4 = tf.layers.dense(h4, units=1)  # no sigmoid: the critic outputs an unbounded score
        return h4
The generator
def generator(z, is_training=is_training):
    momentum = 0.9
    with tf.variable_scope('generator', reuse=None):
        d = 4
        h0 = tf.layers.dense(z, units=d * d * 512)
        h0 = tf.reshape(h0, shape=[-1, d, d, 512])  # 4 x 4 x 512
        h0 = tf.nn.relu(tf.contrib.layers.batch_norm(h0, is_training=is_training, decay=momentum))
        h1 = tf.layers.conv2d_transpose(h0, kernel_size=5, filters=256, strides=2, padding='same')  # 8 x 8
        h1 = tf.nn.relu(tf.contrib.layers.batch_norm(h1, is_training=is_training, decay=momentum))
        h2 = tf.layers.conv2d_transpose(h1, kernel_size=5, filters=128, strides=2, padding='same')  # 16 x 16
        h2 = tf.nn.relu(tf.contrib.layers.batch_norm(h2, is_training=is_training, decay=momentum))
        h3 = tf.layers.conv2d_transpose(h2, kernel_size=5, filters=64, strides=2, padding='same')  # 32 x 32
        h3 = tf.nn.relu(tf.contrib.layers.batch_norm(h3, is_training=is_training, decay=momentum))
        h4 = tf.layers.conv2d_transpose(h3, kernel_size=5, filters=3, strides=2, padding='same', activation=tf.nn.tanh, name='g')  # 64 x 64 x 3
        return h4
Loss functions
g = generator(noise)
d_real = discriminator(X)
d_fake = discriminator(g, reuse=True)
loss_d_real = -tf.reduce_mean(d_real)
loss_d_fake = tf.reduce_mean(d_fake)
loss_g = -tf.reduce_mean(d_fake)
loss_d = loss_d_real + loss_d_fake
alpha = tf.random_uniform(shape=[batch_size, 1, 1, 1], minval=0., maxval=1.)
interpolates = alpha * X + (1 - alpha) * g
grad = tf.gradients(discriminator(interpolates, reuse=True), [interpolates])[0]
slopes = tf.sqrt(tf.reduce_sum(tf.square(grad), axis=[1, 2, 3]))  # per-sample gradient norm over all pixels and channels
gp = tf.reduce_mean((slopes - 1.) ** 2)
loss_d += LAMBDA * gp
vars_g = [var for var in tf.trainable_variables() if var.name.startswith('generator')]
vars_d = [var for var in tf.trainable_variables() if var.name.startswith('discriminator')]
Optimizers
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    optimizer_d = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_d, var_list=vars_d)
    optimizer_g = tf.train.AdamOptimizer(learning_rate=0.0002, beta1=0.5).minimize(loss_g, var_list=vars_g)
Function for reading images
def read_image(path, height, width):
    image = imread(path)
    h = image.shape[0]
    w = image.shape[1]
    # center-crop to a square before resizing
    if h > w:
        image = image[h // 2 - w // 2: h // 2 + w // 2, :, :]
    else:
        image = image[:, w // 2 - h // 2: w // 2 + h // 2, :]
    image = cv2.resize(image, (width, height))
    return image / 255.
Function for tiling a batch of images into one montage image
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m
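As a quick illustration of what montage produces (hypothetical data, not part of the original training script): 100 images of 64x64x3 are tiled into a 10x10 grid with 1-pixel gray borders, giving a 651x651x3 array.
demo = np.random.rand(100, 64, 64, 3)   # stand-in for a batch of generated images
print(montage(demo).shape)              # (651, 651, 3)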
Function for sampling a random batch
def get_random_batch(nums):
    img_index = np.arange(len(images))
    np.random.shuffle(img_index)
    img_index = img_index[:nums]
    batch = np.array([read_image(images[i], HEIGHT, WIDTH) for i in img_index])
    batch = (batch - 0.5) * 2  # rescale from [0, 1] to [-1, 1] to match the generator's tanh output
    return batch
Train the model
sess = tf.Session()
sess.run(tf.global_variables_initializer())
z_samples = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
samples = []
loss = {'d': [], 'g': []}
for i in tqdm(range(60000)):
    # train the critic DIS_ITERS times for every generator update
    for j in range(DIS_ITERS):
        n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
        batch = get_random_batch(batch_size)
        _, d_ls = sess.run([optimizer_d, loss_d], feed_dict={X: batch, noise: n, is_training: True})
    _, g_ls = sess.run([optimizer_g, loss_g], feed_dict={X: batch, noise: n, is_training: True})

    loss['d'].append(d_ls)
    loss['g'].append(g_ls)

    if i % 500 == 0:
        print(i, d_ls, g_ls)
        gen_imgs = sess.run(g, feed_dict={noise: z_samples, is_training: False})
        gen_imgs = (gen_imgs + 1) / 2  # map tanh output from [-1, 1] back to [0, 1]
        imgs = [img[:, :, :] for img in gen_imgs]
        gen_imgs = montage(imgs)
        plt.axis('off')
        plt.imshow(gen_imgs)
        imsave(os.path.join(OUTPUT_DIR, 'sample_%d.jpg' % i), gen_imgs)
        plt.show()
        samples.append(gen_imgs)
plt.plot(loss['d'], label='Discriminator')
plt.plot(loss['g'], label='Generator')
plt.legend(loc='upper right')
plt.savefig(os.path.join(OUTPUT_DIR, 'Loss.png'))
plt.show()
mimsave(os.path.join(OUTPUT_DIR, 'samples.gif'), samples, fps=10)
The faces generated on LFW are shown below; compared with DCGAN, training is noticeably more stable
The faces generated on CelebA are shown below
Save the model for later use
saver = tf.train.Saver()
saver.save(sess, os.path.join(OUTPUT_DIR, 'wgan_' + dataset), global_step=60000)
Generate face images locally with the saved model
# -*- coding: utf-8 -*-
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
batch_size = 100
z_dim = 100
# dataset = 'lfw_new_imgs'
dataset = 'celeba'
def montage(images):
    if isinstance(images, list):
        images = np.array(images)
    img_h = images.shape[1]
    img_w = images.shape[2]
    n_plots = int(np.ceil(np.sqrt(images.shape[0])))
    if len(images.shape) == 4 and images.shape[3] == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 3)) * 0.5
    elif len(images.shape) == 4 and images.shape[3] == 1:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1, 1)) * 0.5
    elif len(images.shape) == 3:
        m = np.ones(
            (images.shape[1] * n_plots + n_plots + 1,
             images.shape[2] * n_plots + n_plots + 1)) * 0.5
    else:
        raise ValueError('Could not parse image shape of {}'.format(images.shape))
    for i in range(n_plots):
        for j in range(n_plots):
            this_filter = i * n_plots + j
            if this_filter < images.shape[0]:
                this_img = images[this_filter]
                m[1 + i + i * img_h:1 + i + (i + 1) * img_h,
                  1 + j + j * img_w:1 + j + (j + 1) * img_w] = this_img
    return m
sess = tf.Session()
sess.run(tf.global_variables_initializer())
saver = tf.train.import_meta_graph(os.path.join('samples_' + dataset, 'wgan_' + dataset + '-60000.meta'))
saver.restore(sess, tf.train.latest_checkpoint('samples_' + dataset))
graph = tf.get_default_graph()
g = graph.get_tensor_by_name('generator/g/Tanh:0')
noise = graph.get_tensor_by_name('noise:0')
is_training = graph.get_tensor_by_name('is_training:0')
n = np.random.uniform(-1.0, 1.0, [batch_size, z_dim]).astype(np.float32)
gen_imgs = sess.run(g, feed_dict={noise: n, is_training: False})
gen_imgs = (gen_imgs + 1) / 2
imgs = [img[:, :, :] for img in gen_imgs]
gen_imgs = montage(imgs)
gen_imgs = np.clip(gen_imgs, 0, 1)
plt.figure(figsize=(8, 8))
plt.axis('off')
plt.imshow(gen_imgs)
plt.show()