1. 程式人生 > 其它 >生成對抗網路(GAN)系列(一)

生成對抗網路(GAN)系列(一)

生成式模型的作用

密度估計

給定一組資料\(D=\left \{ x^{n} \right \}^{N}_{n=1}\),假設它們都是獨立地從相同的概率密度函式為\(p_{r}(x)\)的未知分佈中產生的。密度估計是根據資料集\(D\)來估計其概率密度函式\(p_{\theta}(x)\)
在機器學習中,密度估計是一類無監督學習問題。比如在手寫體數字影象的密度估計問題中,我們將影象表示為一個隨機變數\(X\),其中每一維都表示一個畫素值。假設手寫體數字影象都服從一個未知的分佈\(p_{r}{x}\),希望通過一些觀測樣本來估計其分佈。但是手寫體數字影象中不同畫素之間存在複雜的依賴關係,很難用一個明確的圖模型來描述其依賴關係,所以直接建模\(p_{r}{x}\)

,比較困難。因此,我們通過引入隱變數\(z\)來簡化模型,這樣密度估計問題可以轉化為估計變數(x, z)的兩個區域性條件概率\(p_{\theta}(z)\)\(p_{\theta}(x|z)\)。一般為了簡化模型,假設隱變數\(z\)的先驗分佈為標準高斯分佈\(N(0, I)\)。隱變數\(z\)的每一維度之間都是獨立的,密度估計的重點是估計條件分佈\(p(x|z; \theta)\)
如果要建模隱含變數的分佈,就需要用EM演算法來進行密度估計,而在EM演算法中,需要估計條件分佈\(p(x|z; \theta)\)以及後驗概率分佈\(p(z|x; \theta)\)。當這兩個分佈比較複雜時,就可以利用神經網路來建模(如變分自編碼器)。

生成樣本

在知道\(p_{\theta}(z)\)和得到\(p_{\theta}(x|z)\)之後就可以生成新的資料:

  • 從隱變數的先驗分佈\(p_{\theta}(z)\)中取樣,得到樣本\(z\)
  • 根據條件概率分佈\(p_{\theta}(x|z)\)進行取樣,得到新的樣本\(x\)

生成對抗網路

本文的重點是生成對抗網路(GAN)。與一般的生成式模型(如VAE、DQN)不同,GAN並不直接建模\(p(x)\),而是直接通過一個神經網路學習從隱變數\(z\)到資料\(x\)的對映,稱為生成器;然後將生成的樣本交給判別網路判斷是否是真實的樣本。可以看出,生成網路和判別網路的訓練是彼此依存、交替進行的。

生成對抗網路流程圖

判別網路

判別網路\(D(\boldsymbol x;\phi )\)的目標是區分出一個樣本\(\boldsymbol x\)是來自於真實分佈\(p_{r}(\boldsymbol x)\)還是來自於生成模型\(p_{\theta}(\boldsymbol x)\)。由此可見,判別網路實際上是一個二分類的分類器。用標籤\(y=1\)來表示樣本來自真實分佈,\(y=0\)表示樣本來自生成模型,判別網路\(D(\boldsymbol x;\phi )\)的輸出為\(\boldsymbol x\)屬於真實資料分佈的概率,即:

\[p(y=1 | x) = D(\boldsymbol x;\phi ) \]

樣本來自生成模型的概率為:

\[p(y=0 | x) = 1 - D(\boldsymbol x;\phi ) \]

給定一個樣本\((x,y),y= \left \{ 1,0 \right \}\),表示其來自於\(p_{r}(\boldsymbol x)\)還是\(p_{\theta}(\boldsymbol x)\),判別網路的目標函式為最小化交叉熵,即:

\[\mathop{min}_{\phi }-\left ( \mathbb{E}_{x}\left [ ylogp(y=1| \boldsymbol x) + (1-y)log p(y=0| \boldsymbol x)\right ] \right ) \]

假設分佈\(p(\boldsymbol x)\)是由分佈\(p_{r}(\boldsymbol x)\)和分佈\(p_{\theta}(\boldsymbol x)\)等比例混合而成,即\(p(\boldsymbol x) = \frac{1}{2} * \left (p_{r}(\boldsymbol x) + p_{\theta}(\boldsymbol x) \right )\),則上式等價於:

\[\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol x ^{'} \sim p_{\theta}(\boldsymbol x ^{'})}\left [ log\left ( 1 - D(\boldsymbol x ^{'} ;\phi ) \right ) \right ] \]\[=\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z )}\left [ log\left ( 1 - D(G(\boldsymbol z ;\theta ) ;\phi ) \right ) \right ] \]

其中\(\theta\)\(\phi\)分別是生成網路和判別網路的引數。

生成網路

生成網路的目標剛好和判別網路相反,即讓判別網路將自己生成的樣本判別為真是樣本。

\[\mathop{max}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ logD \left (G (\boldsymbol z; \theta ) ;\phi \right ) \right ] \]\[=\mathop{min}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ log(1 - D \left (G (\boldsymbol z; \theta ) ;\phi \right )) \right ] \]

兩個目標函式是等價的,但一般使用前者,因為其梯度性質更好。

訓練

和單目標的優化任務相比,生成對抗網路的兩個網路的優化目標剛好相反。因此生成對抗網路的訓練比較難,往往不太穩定. 一般情況下,需要平衡兩個網路的能力。對於判別網路來說,一開始的判別能力不能太強,否則難以提升生成網路的能力。但是,判別網路的判別能力也不能太弱,否則針對它訓練的生成網路也不會太好。 在訓練時需要使用一些技巧,使得在每次迭代中,判別網路比生成網路的能力強一些,但又不能強太多。具體做法是,判別網路更新\(K\)次,生成網路更新1次。

生成對抗網路訓練過程

程式碼實現

hyperparam.py檔案
超引數配置模組

import argparse


class HyperParam:
    def __init__(self):
        self.parse = argparse.ArgumentParser()
        self.parse.add_argument("--latent_dim", type=int, default=5)  # 隱含變數的維度
        self.parse.add_argument("--data_dim", type=int, default=10)  # 觀測變數的維度
        self.parse.add_argument("--data_size", type=int, default=10000)  # 樣本數
        self.parse.add_argument("--g_lr", type=float, default=0.001)
        self.parse.add_argument("--d_lr", type=float, default=0.001)
        self.parse.add_argument("--epochs", type=int, default=300)
        self.parse.add_argument("--K", type=int, default=5)
        self.parse.add_argument("--sample_size", type=int, default=128)
        self.parse.add_argument("--batch_size", type=int, default=128)

gan.py檔案
GAN的實現部分

import numpy as np
import torch
from hyperparam import HyperParam
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt

np.random.seed(1000)
torch.manual_seed(1000)


def get_real_data(data_dim, data_size, batch_size):
    base = np.linspace(-1, 1, data_dim)
    a = np.random.uniform(8, 15, data_size).reshape(-1, 1)
    c = np.random.uniform(0.5, 10, data_size).reshape(-1, 1)

    # 構造真實資料
    X = a * np.power(base, 2) + c
    X = torch.from_numpy(X).type(torch.float32)
    data_set = Data.TensorDataset(X)
    data_loader = Data.DataLoader(dataset=data_set,
                                  batch_size=batch_size)

    return base, data_loader


class GAN(nn.Module):
    def __init__(self, latent_dim, data_dim, K, sample_size):
        super().__init__()
        self.latent_dim = latent_dim
        self.data_dim = data_dim
        self.K = K
        self.sample_size = sample_size

        self.g = self._generator()
        self.d = self._discriminator()

        self.g_optimizer = torch.optim.Adam(self.g.parameters(), lr=0.001)
        self.d_optimizer = torch.optim.Adam(self.d.parameters(), lr=0.001)

    def _generator(self):
        model = nn.Sequential(
            nn.Linear(self.latent_dim, 64),
            nn.ReLU(),
            nn.Linear(64, self.data_dim)
        )
        return model

    def _discriminator(self):
        model = nn.Sequential(
            nn.Linear(self.data_dim, 64),
            nn.ReLU(),
            nn.Linear(64, 1),
            nn.Sigmoid()
        )
        return model

    def d_loss_fn(self, pred_data_result, true_data_result):
        return -torch.mean(torch.log(true_data_result) + torch.log(1 - pred_data_result))

    def g_loss_fn(self, pred_data_result):
        return -torch.mean(torch.log(pred_data_result))

    def train_d(self, true_data):
        sample_size = true_data.shape[0]
        for i in range(self.K):
            # 取樣
            sample = torch.rand(sample_size, self.latent_dim)
            # 生成
            fake_data = self.g(sample)
            # 生成資料的判定結果
            fake_data_result = self.d(fake_data)

            # 真實資料的判定結果
            true_data_result = self.d(true_data)

            loss = self.d_loss_fn(fake_data_result, true_data_result)
            self.d_optimizer.zero_grad()
            loss.backward()
            self.d_optimizer.step()

    def train_g(self):
        # 取樣
        sample = torch.rand(self.sample_size, self.latent_dim)
        # 生成
        fake_data = self.g(sample)
        # 生成資料的判定結果
        fake_data_result = self.d(fake_data)

        loss = self.g_loss_fn(fake_data_result)
        self.g_optimizer.zero_grad()
        loss.backward()
        self.g_optimizer.step()

    def step(self, true_data):
        self.train_d(true_data)  # 先訓練判別器
        self.train_g()  # 再訓練生成器


def train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base):
    print('正在訓練......')
    model = GAN(latent_dim, data_dim, K, sample_size)

    plt.ion()
    for epoch in range(epochs):
        for true_data in data_loader:
            model.step(true_data[0])  # [128, 15]
        if (epoch + 1) % 50 == 0:
            print('epoch: [{}/{}]'.format(epoch + 1, epochs))
            # 取樣
            sample = torch.rand(1, latent_dim)
            # 生成
            fake_data = model.g(sample)
            plt.cla()
            plt.plot(base, fake_data.data.numpy().flatten())
            plt.show()
            plt.pause(0.1)
    plt.ioff()
    plt.show()

    torch.save(model.state_dict(), 'gan_param.pkl')
    print('模型儲存成功')


if __name__ == "__main__":
    instance = HyperParam()
    hp = instance.parse.parse_args()
    epochs = hp.epochs
    latent_dim = hp.latent_dim
    data_dim = hp.data_dim
    K = hp.K
    sample_size = hp.sample_size
    data_size = hp.data_size
    batch_size = hp.batch_size

    base, data_loader = get_real_data(data_dim, data_size, batch_size)
    train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base)

執行結果及分析

執行結果

從圖中可以看出,從左到右,生成模型繪製二次曲線的能力越來越強了,訓練500個epoch之後,生成的圖形比較接近真實的二次曲線。

結果分析

實際執行程式時會發現,GAN的生成效果對啟用函式和超引數的依賴非常大,特別是超引數K(訓練K次判別器之後再訓練一次生成器)的取值,如果K的取值稍微不合理,那麼會直接導致生成器的損失太大,無法繼續優化下去。此外,GAN需要足夠的多的樣本學習,特別是如果隱變數維度較多的話,需要更多的樣本才有可能學得比較好的模型;模型訓練過程中存在明顯的震盪現象。

GAN的優缺點分析

優點

  • GAN是一種生成式模型,相比較其他生成模型(玻爾茲曼機和GSNs)只用到了反向傳播,而不需要複雜的馬爾科夫鏈。
  • 相比其他所有模型, GAN可以產生更加清晰,真實的樣本。
  • GAN採用的是一種無監督的學習方式訓練,可以被廣泛用在無監督學習和半監督學習領域。
  • 相比於變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,因為他們優化對數似然的下界,而不是似然度本身,這看起來導致了VAEs生成的例項比GANs更模糊。
  • 相比VAE, GANs沒有變分下界,如果鑑別器訓練良好,那麼生成器可以完美的學習到訓練樣本的分佈。換句話說,GANs是漸進一致的,但是VAE是有偏差的。

缺點

  • GAN不適合處理離散形式的資料,比如文字。
  • GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)

關於GAN的一些問題

模式崩潰的原因

一般出現在GAN訓練不穩定的時候,具體表現為生成出來的結果非常差,但是即使加長訓練時間後也無法得到很好的改善。
具體原因可以解釋如下:GAN採用的是對抗訓練的方式,G的梯度更新來自D,所以G生成的好不好,得看D怎麼說。具體就是G生成一個樣本,交給D去評判,D會輸出生成的假樣本是真樣本的概率(0-1),相當於告訴G生成的樣本有多大的真實性,G就會根據這個反饋不斷改善自己,提高D輸出的概率值。但是如果某一次G生成的樣本可能並不是很真實,但是D給出了正確的評價,或者是G生成的結果中一些特徵得到了D的認可,這時候G就會認為我輸出的正確的,那麼接下來我就這樣輸出肯定D還會給出比較高的評價,實際上G生成的並不怎麼樣,但是他們兩個就這樣自我欺騙下去了,導致最終生成結果缺失一些資訊,特徵不全。

為什麼優化器不常用SGD

  • SGD容易震盪,容易使GAN訓練不穩定。
  • GAN的目的是在高維非凸的引數空間中找到納什均衡點,GAN的納什均衡點是一個鞍點,但是SGD只會找到區域性極小值,因為SGD解決的是一個尋找最小值的問題,GAN是一個博弈問題。

為什麼不適合處理文字資料

  • 文字資料相比較圖片資料來說是離散的,因為對於文字來說,通常需要將一個詞對映為一個高維的向量,最終預測的輸出是一個one-hot向量,假設softmax的輸出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那麼變為onehot是(0,1,0,0,0,0),如果softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以對於生成器來說,G輸出了不同的結果但是D給出了同樣的判別結果,並不能將
  • GAN的損失函式是JS散度,JS散度不適合衡量不想交分佈之間的距離。

訓練GAN的技巧

  • 輸入規範化到(-1,1)之間,最後一層的啟用函式使用tanh(BEGAN除外)
  • 使用wassertein GAN的損失函式
  • 如果有標籤資料的話,儘量使用標籤,也有人提出使用反轉標籤效果很好,另外使用標籤平滑,單邊標籤平滑或者雙邊標籤平滑
  • 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm
  • 避免使用RELU和pooling層,減少稀疏梯度的可能性,可以使用leakrelu啟用函式
  • 優化器儘量選擇ADAM,學習率不要設定太大,初始1e-4可以參考,另外可以隨著訓練進行不斷縮小學習率
  • 給D的網路層增加高斯噪聲,相當於是一種正則

參考