1. 程式人生 > 其它 >Tensorflow2.0實戰之GAN

Tensorflow2.0實戰之GAN

本文主要帶領讀者瞭解生成對抗神經網路(GAN),並使用提供的face資料集訓練網路

GAN 入門

自 2014 年 Ian Goodfellow 的《生成對抗網路(Generative Adversarial Networks)》論文發表以來,GAN 的進展突飛猛進,生成結果也越來越具有照片真實感。
就在三年前,Ian Goodfellow 在 reddit 上回答 GAN 是否可以應用在文字領域的問題時,還認為 GAN 不能擴充套件到文字領域。

“由於 GAN 定義在實值資料上,因此 GAN 不能應用於 NLP。
GAN 的工作原理是訓練一個生成網路,輸出合成數據,然後利用判別網路判別合成數據。判別網路根據合成數據輸出的梯度告訴你該如何對合成資料進行微調,使其更真實。
因此只有當合成資料是基於連續數字時,才能對其進行微調。如果是基於離散的數字,就沒有辦法做微小的改變。
例如,如果輸出畫素值為 1.0 的影象,則下一步可以將該畫素值更改為 1.0001。
但如果輸出單詞‘penguin’,不能在下一步直接將其更改為‘penguin+.001’,因為沒有‘penguin+.001’這樣的單詞。你必須從‘penguin’直接轉變到‘ostrich’。
由於所有的 NLP 都是基於離散的值,如單詞、字元或位元組,所以目前還沒有人知道該如何將 GAN 應用於 NLP。”

但是現在,GAN 已經可用於生成各種內容,包括影象、視訊、音訊和文字。這些輸出的合成數據既可以用於訓練其他的模型,也可以用於建立一些有趣的專案。

GAN 原理

GAN 由兩個神經網路組成,一個是合成新樣本的生成器,另一個是對比訓練樣本與生成樣本的判別器。判別器的目標是區分“真實”和“虛假”的輸入(對樣本來自模型分佈還是真實分佈進行分類)。這些樣本可以是影象、視訊、音訊片段和文字。

為了合成這些新的樣本,生成器的輸入為隨機噪聲,然後嘗試從訓練資料中學習到的分佈中生成真實的影象。
判別器網路(卷積神經網路)輸出相對於合成數據的梯度,其中包含著如何改變合成數據以使其更具真實感的資訊。最終生成器收斂,它可以生成符合真實資料分佈的樣本,而判別器無法區分生成資料和真實資料。
ok,接下來我們就來實現一下

準備階段

下載資料集
資料集,筆者這裡已經為大家提供了,連結如下:
連結: https://pan.baidu.com/s/15wFZAANvr8gajiVY_1mI0A
提取碼: c9vy
解壓資料集
將下載好的資料集解壓,放在工程目錄下

載入資料集
載入資料集的程式碼,筆者這裡直接提供給大家了,下面只是展示部分程式碼,文末會提供完整專案的程式碼連結

import multiprocessing
import tensorflow as tf
def make_anime_dataset(img_paths, batch_size, resize=64, drop_remainder=True, shuffle=True, repeat=1):
    @tf.function
    def _map_fn(img):
        img = tf.image.resize(img, [resize, resize])
        img = tf.clip_by_value(img, 0, 255)
        img = img / 127.5 - 1
        return img
    dataset = disk_image_batch_dataset(img_paths,
                                          batch_size,
                                          drop_remainder=drop_remainder,
                                          map_fn=_map_fn,
                                          shuffle=shuffle,
                                          repeat=repeat)
    img_shape = (resize, resize, 3)
    len_dataset = len(img_paths) // batch_size
    return dataset, img_shape, len_dataset
def batch_dataset(dataset,
                  batch_size,
                  drop_remainder=True,
                  n_prefetch_batch=1,
                  filter_fn=None,
                  map_fn=None,
                  n_map_threads=None,
                  filter_after_map=False,
                  shuffle=True,
                  shuffle_buffer_size=None,
                  repeat=None):

構建網路
搭建Generator,Generator包含兩個部分,init部分和前向傳播的call部分,程式碼如下

class Generator(keras.Model):
    def __init__(self):
        super(Generator, self).__init__()
        # z:[b,100]-->[b,3*3*512]-->[b,3,3,512]-->[b,64,64,3]
        self.fc=keras.layers.Dense(3*3*512)

        self.conv1=keras.layers.Conv2DTranspose(256,3,3,'valid')  # 反捲積
        self.bn1=keras.layers.BatchNormalization()

        self.conv2=keras.layers.Conv2DTranspose(128,5,2,'valid')
        self.bn2=keras.layers.BatchNormalization()

        self.conv3=keras.layers.Conv2DTranspose(3,4,3,'valid')

    def call(self, inputs, training=None, mask=None):
        # [z,100]-->[z,3*3*512]
        x=self.fc(inputs)
        x=tf.reshape(x,[-1,3,3,512])
        x=tf.nn.leaky_relu(x)

        x=tf.nn.leaky_relu(self.bn1(self.conv1(x),training=training))
        x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        x=self.conv3(x)
        x=tf.tanh(x)
        return x

搭建Discriminator,同上

class Discriminator(keras.Model):
    def __init__(self):
        super(Discriminator, self).__init__()
        # [b,64,64,3]-->[b,1]
        self.conv1=keras.layers.Conv2D(64,5,3,'valid')

        self.conv2=keras.layers.Conv2D(128,5,3,'valid')
        self.bn2=keras.layers.BatchNormalization()

        self.conv3=keras.layers.Conv2D(256,5,3,'valid')
        self.bn3=keras.layers.BatchNormalization()

        # [b,h,w,c]-->[b,-1]
        self.flatten=keras.layers.Flatten()
        # [b,-1]-->[b,1]
        self.fc=keras.layers.Dense(1)
    def call(self, inputs, training=None, mask=None):
        x=tf.nn.leaky_relu(self.conv1(inputs))
        x=tf.nn.leaky_relu(self.bn2(self.conv2(x),training=training))
        x=tf.nn.leaky_relu(self.bn3(self.conv3(x),training=training))
        x=self.flatten(x)
        logits=self.fc(x)
        return logits

訓練GAN
定義相關資料,包括epoch,lr等等
這些資料可以自定義,筆者這裡就不改動了

	 z_dim = 100
    epochs = 50000
    batch_size = 512
    learning_rate = 0.0002
    is_training = True

載入資料

	img_path=glob.glob(r'E:\python_pro\TF2.0\GAN\faces\*.jpg')
    dataset, img_shape, _ = make_anime_dataset(img_path, batch_size)

可以列印檢視資料集資訊:

(512, 64, 64, 3), (64, 64, 3)
(512, 64, 64, 3) ,1.0, -1.0

定義優化器,注意我們在開始訓練時,需要新建訓練GAN圖片的檔案,為檢視資料提供持久化依據

    for epoch in range(epochs):

        batch_z = tf.random.uniform([batch_size, z_dim], minval=-1., maxval=1.)
        batch_x = next(db_iter)

        # train D
        with tf.GradientTape() as tape:
            d_loss = d_loss_fn(generator, discriminator, batch_z, batch_x, is_training)
        grads = tape.gradient(d_loss, discriminator.trainable_variables)
        d_optimizer.apply_gradients(zip(grads, discriminator.trainable_variables))


        with tf.GradientTape() as tape:
            g_loss = g_loss_fn(generator, discriminator, batch_z, is_training)
        grads = tape.gradient(g_loss, generator.trainable_variables)
        g_optimizer.apply_gradients(zip(grads, generator.trainable_variables))

        if epoch % 100 == 0:
            print(epoch, 'd-loss:',float(d_loss), 'g-loss:', float(g_loss))

            z = tf.random.uniform([100, z_dim])
            fake_image = generator(z, training=False)
            img_path = os.path.join('GAN_IMAGE', 'gan%d.png'%epoch)
            save_result(fake_image.numpy(), 10, img_path, color_mode='P')

訓練結果

接下來我們來看看,訓練的效果圖,注意,GAN的訓練過程是非常非常非常慢的,大概訓練十幾個小時,才能有個比較好的效果,有的資料集甚至會訓練幾天之久,這個隨資料集的大小和對最終效果的要求來定的。筆者這個資料集比較的簡單,只是給大家做演示,好了,廢話就不過多的說了,上圖




上述分別是訓練了100epoch、500、1500、4000的效果圖,可以看到隨著訓練的次數增加,效果因為越來越好了

總結

大家在訓練GAN時,還是需要一個好一些的GPU顯示卡才行,這樣可以體驗GPU給我們帶來的加速效果。這樣會使得訓練的速度大大加快。
筆者水平有限,如有表述不準確的地方還請諒解,有錯誤的地方歡迎大家批評指正。
最後還是希望大家動手實踐實踐,共同進步。
最終的程式碼連結:https://github.com/huzixuan1/TF_2.0/tree/master/GAN