生成對抗網路(GAN)系列(一)
生成式模型的作用
密度估計
給定一組資料\(D=\left \{ x^{n} \right \}^{N}_{n=1}\),假設它們都是獨立地從相同的概率密度函式為\(p_{r}(x)\)的未知分佈中產生的。密度估計是根據資料集\(D\)來估計其概率密度函式\(p_{\theta}(x)\)。
在機器學習中,密度估計是一類無監督學習問題。比如在手寫體數字影象的密度估計問題中,我們將影象表示為一個隨機變數\(X\),其中每一維都表示一個畫素值。假設手寫體數字影象都服從一個未知的分佈\(p_{r}{x}\),希望通過一些觀測樣本來估計其分佈。但是手寫體數字影象中不同畫素之間存在複雜的依賴關係,很難用一個明確的圖模型來描述其依賴關係,所以直接建模\(p_{r}{x}\)
如果要建模隱含變數的分佈,就需要用EM演算法來進行密度估計,而在EM演算法中,需要估計條件分佈\(p(x|z; \theta)\)以及後驗概率分佈\(p(z|x; \theta)\)。當這兩個分佈比較複雜時,就可以利用神經網路來建模(如變分自編碼器)。
生成樣本
在知道\(p_{\theta}(z)\)和得到\(p_{\theta}(x|z)\)之後就可以生成新的資料:
- 從隱變數的先驗分佈\(p_{\theta}(z)\)中取樣,得到樣本\(z\)。
- 根據條件概率分佈\(p_{\theta}(x|z)\)進行取樣,得到新的樣本\(x\)
生成對抗網路
本文的重點是生成對抗網路(GAN)。與一般的生成式模型(如VAE、DQN)不同,GAN並不直接建模\(p(x)\),而是直接通過一個神經網路學習從隱變數\(z\)到資料\(x\)的對映,稱為生成器;然後將生成的樣本交給判別網路判斷是否是真實的樣本。可以看出,生成網路和判別網路的訓練是彼此依存、交替進行的。
判別網路
判別網路\(D(\boldsymbol x;\phi )\)的目標是區分出一個樣本\(\boldsymbol x\)是來自於真實分佈\(p_{r}(\boldsymbol x)\)還是來自於生成模型\(p_{\theta}(\boldsymbol x)\)。由此可見,判別網路實際上是一個二分類的分類器。用標籤\(y=1\)來表示樣本來自真實分佈,\(y=0\)表示樣本來自生成模型,判別網路\(D(\boldsymbol x;\phi )\)的輸出為\(\boldsymbol x\)屬於真實資料分佈的概率,即:
\[p(y=1 | x) = D(\boldsymbol x;\phi ) \]樣本來自生成模型的概率為:
\[p(y=0 | x) = 1 - D(\boldsymbol x;\phi ) \]給定一個樣本\((x,y),y= \left \{ 1,0 \right \}\),表示其來自於\(p_{r}(\boldsymbol x)\)還是\(p_{\theta}(\boldsymbol x)\),判別網路的目標函式為最小化交叉熵,即:
\[\mathop{min}_{\phi }-\left ( \mathbb{E}_{x}\left [ ylogp(y=1| \boldsymbol x) + (1-y)log p(y=0| \boldsymbol x)\right ] \right ) \]假設分佈\(p(\boldsymbol x)\)是由分佈\(p_{r}(\boldsymbol x)\)和分佈\(p_{\theta}(\boldsymbol x)\)等比例混合而成,即\(p(\boldsymbol x) = \frac{1}{2} * \left (p_{r}(\boldsymbol x) + p_{\theta}(\boldsymbol x) \right )\),則上式等價於:
\[\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol x ^{'} \sim p_{\theta}(\boldsymbol x ^{'})}\left [ log\left ( 1 - D(\boldsymbol x ^{'} ;\phi ) \right ) \right ] \]\[=\mathop{max}_{\phi } \mathbb{E}_{\boldsymbol x \sim p_{r}(\boldsymbol x)}\left [ logD(\boldsymbol x ;\phi ) \right ] + \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z )}\left [ log\left ( 1 - D(G(\boldsymbol z ;\theta ) ;\phi ) \right ) \right ] \]其中\(\theta\)和\(\phi\)分別是生成網路和判別網路的引數。
生成網路
生成網路的目標剛好和判別網路相反,即讓判別網路將自己生成的樣本判別為真是樣本。
\[\mathop{max}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ logD \left (G (\boldsymbol z; \theta ) ;\phi \right ) \right ] \]\[=\mathop{min}_{\theta } \mathbb{E}_{\boldsymbol z \sim p(\boldsymbol z)}\left [ log(1 - D \left (G (\boldsymbol z; \theta ) ;\phi \right )) \right ] \]兩個目標函式是等價的,但一般使用前者,因為其梯度性質更好。
訓練
和單目標的優化任務相比,生成對抗網路的兩個網路的優化目標剛好相反。因此生成對抗網路的訓練比較難,往往不太穩定. 一般情況下,需要平衡兩個網路的能力。對於判別網路來說,一開始的判別能力不能太強,否則難以提升生成網路的能力。但是,判別網路的判別能力也不能太弱,否則針對它訓練的生成網路也不會太好。 在訓練時需要使用一些技巧,使得在每次迭代中,判別網路比生成網路的能力強一些,但又不能強太多。具體做法是,判別網路更新\(K\)次,生成網路更新1次。
程式碼實現
hyperparam.py檔案
超引數配置模組
import argparse
class HyperParam:
def __init__(self):
self.parse = argparse.ArgumentParser()
self.parse.add_argument("--latent_dim", type=int, default=5) # 隱含變數的維度
self.parse.add_argument("--data_dim", type=int, default=10) # 觀測變數的維度
self.parse.add_argument("--data_size", type=int, default=10000) # 樣本數
self.parse.add_argument("--g_lr", type=float, default=0.001)
self.parse.add_argument("--d_lr", type=float, default=0.001)
self.parse.add_argument("--epochs", type=int, default=300)
self.parse.add_argument("--K", type=int, default=5)
self.parse.add_argument("--sample_size", type=int, default=128)
self.parse.add_argument("--batch_size", type=int, default=128)
gan.py檔案
GAN的實現部分
import numpy as np
import torch
from hyperparam import HyperParam
import torch.nn as nn
import torch.utils.data as Data
import matplotlib.pyplot as plt
np.random.seed(1000)
torch.manual_seed(1000)
def get_real_data(data_dim, data_size, batch_size):
base = np.linspace(-1, 1, data_dim)
a = np.random.uniform(8, 15, data_size).reshape(-1, 1)
c = np.random.uniform(0.5, 10, data_size).reshape(-1, 1)
# 構造真實資料
X = a * np.power(base, 2) + c
X = torch.from_numpy(X).type(torch.float32)
data_set = Data.TensorDataset(X)
data_loader = Data.DataLoader(dataset=data_set,
batch_size=batch_size)
return base, data_loader
class GAN(nn.Module):
def __init__(self, latent_dim, data_dim, K, sample_size):
super().__init__()
self.latent_dim = latent_dim
self.data_dim = data_dim
self.K = K
self.sample_size = sample_size
self.g = self._generator()
self.d = self._discriminator()
self.g_optimizer = torch.optim.Adam(self.g.parameters(), lr=0.001)
self.d_optimizer = torch.optim.Adam(self.d.parameters(), lr=0.001)
def _generator(self):
model = nn.Sequential(
nn.Linear(self.latent_dim, 64),
nn.ReLU(),
nn.Linear(64, self.data_dim)
)
return model
def _discriminator(self):
model = nn.Sequential(
nn.Linear(self.data_dim, 64),
nn.ReLU(),
nn.Linear(64, 1),
nn.Sigmoid()
)
return model
def d_loss_fn(self, pred_data_result, true_data_result):
return -torch.mean(torch.log(true_data_result) + torch.log(1 - pred_data_result))
def g_loss_fn(self, pred_data_result):
return -torch.mean(torch.log(pred_data_result))
def train_d(self, true_data):
sample_size = true_data.shape[0]
for i in range(self.K):
# 取樣
sample = torch.rand(sample_size, self.latent_dim)
# 生成
fake_data = self.g(sample)
# 生成資料的判定結果
fake_data_result = self.d(fake_data)
# 真實資料的判定結果
true_data_result = self.d(true_data)
loss = self.d_loss_fn(fake_data_result, true_data_result)
self.d_optimizer.zero_grad()
loss.backward()
self.d_optimizer.step()
def train_g(self):
# 取樣
sample = torch.rand(self.sample_size, self.latent_dim)
# 生成
fake_data = self.g(sample)
# 生成資料的判定結果
fake_data_result = self.d(fake_data)
loss = self.g_loss_fn(fake_data_result)
self.g_optimizer.zero_grad()
loss.backward()
self.g_optimizer.step()
def step(self, true_data):
self.train_d(true_data) # 先訓練判別器
self.train_g() # 再訓練生成器
def train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base):
print('正在訓練......')
model = GAN(latent_dim, data_dim, K, sample_size)
plt.ion()
for epoch in range(epochs):
for true_data in data_loader:
model.step(true_data[0]) # [128, 15]
if (epoch + 1) % 50 == 0:
print('epoch: [{}/{}]'.format(epoch + 1, epochs))
# 取樣
sample = torch.rand(1, latent_dim)
# 生成
fake_data = model.g(sample)
plt.cla()
plt.plot(base, fake_data.data.numpy().flatten())
plt.show()
plt.pause(0.1)
plt.ioff()
plt.show()
torch.save(model.state_dict(), 'gan_param.pkl')
print('模型儲存成功')
if __name__ == "__main__":
instance = HyperParam()
hp = instance.parse.parse_args()
epochs = hp.epochs
latent_dim = hp.latent_dim
data_dim = hp.data_dim
K = hp.K
sample_size = hp.sample_size
data_size = hp.data_size
batch_size = hp.batch_size
base, data_loader = get_real_data(data_dim, data_size, batch_size)
train(epochs, latent_dim, data_dim, K, sample_size, data_loader, base)
執行結果及分析
執行結果
從圖中可以看出,從左到右,生成模型繪製二次曲線的能力越來越強了,訓練500個epoch之後,生成的圖形比較接近真實的二次曲線。
結果分析
實際執行程式時會發現,GAN的生成效果對啟用函式和超引數的依賴非常大,特別是超引數K(訓練K次判別器之後再訓練一次生成器)的取值,如果K的取值稍微不合理,那麼會直接導致生成器的損失太大,無法繼續優化下去。此外,GAN需要足夠的多的樣本學習,特別是如果隱變數維度較多的話,需要更多的樣本才有可能學得比較好的模型;模型訓練過程中存在明顯的震盪現象。
GAN的優缺點分析
優點
- GAN是一種生成式模型,相比較其他生成模型(玻爾茲曼機和GSNs)只用到了反向傳播,而不需要複雜的馬爾科夫鏈。
- 相比其他所有模型, GAN可以產生更加清晰,真實的樣本。
- GAN採用的是一種無監督的學習方式訓練,可以被廣泛用在無監督學習和半監督學習領域。
- 相比於變分自編碼器, GANs沒有引入任何決定性偏置( deterministic bias),變分方法引入決定性偏置,因為他們優化對數似然的下界,而不是似然度本身,這看起來導致了VAEs生成的例項比GANs更模糊。
- 相比VAE, GANs沒有變分下界,如果鑑別器訓練良好,那麼生成器可以完美的學習到訓練樣本的分佈。換句話說,GANs是漸進一致的,但是VAE是有偏差的。
缺點
- GAN不適合處理離散形式的資料,比如文字。
- GAN存在訓練不穩定、梯度消失、模式崩潰的問題(目前已解決)
關於GAN的一些問題
模式崩潰的原因
一般出現在GAN訓練不穩定的時候,具體表現為生成出來的結果非常差,但是即使加長訓練時間後也無法得到很好的改善。
具體原因可以解釋如下:GAN採用的是對抗訓練的方式,G的梯度更新來自D,所以G生成的好不好,得看D怎麼說。具體就是G生成一個樣本,交給D去評判,D會輸出生成的假樣本是真樣本的概率(0-1),相當於告訴G生成的樣本有多大的真實性,G就會根據這個反饋不斷改善自己,提高D輸出的概率值。但是如果某一次G生成的樣本可能並不是很真實,但是D給出了正確的評價,或者是G生成的結果中一些特徵得到了D的認可,這時候G就會認為我輸出的正確的,那麼接下來我就這樣輸出肯定D還會給出比較高的評價,實際上G生成的並不怎麼樣,但是他們兩個就這樣自我欺騙下去了,導致最終生成結果缺失一些資訊,特徵不全。
為什麼優化器不常用SGD
- SGD容易震盪,容易使GAN訓練不穩定。
- GAN的目的是在高維非凸的引數空間中找到納什均衡點,GAN的納什均衡點是一個鞍點,但是SGD只會找到區域性極小值,因為SGD解決的是一個尋找最小值的問題,GAN是一個博弈問題。
為什麼不適合處理文字資料
- 文字資料相比較圖片資料來說是離散的,因為對於文字來說,通常需要將一個詞對映為一個高維的向量,最終預測的輸出是一個one-hot向量,假設softmax的輸出是(0.2, 0.3, 0.1,0.2,0.15,0.05)那麼變為onehot是(0,1,0,0,0,0),如果softmax輸出是(0.2, 0.25, 0.2, 0.1,0.15,0.1 ),one-hot仍然是(0, 1, 0, 0, 0, 0),所以對於生成器來說,G輸出了不同的結果但是D給出了同樣的判別結果,並不能將
- GAN的損失函式是JS散度,JS散度不適合衡量不想交分佈之間的距離。
訓練GAN的技巧
- 輸入規範化到(-1,1)之間,最後一層的啟用函式使用tanh(BEGAN除外)
- 使用wassertein GAN的損失函式
- 如果有標籤資料的話,儘量使用標籤,也有人提出使用反轉標籤效果很好,另外使用標籤平滑,單邊標籤平滑或者雙邊標籤平滑
- 使用mini-batch norm, 如果不用batch norm 可以使用instance norm 或者weight norm
- 避免使用RELU和pooling層,減少稀疏梯度的可能性,可以使用leakrelu啟用函式
- 優化器儘量選擇ADAM,學習率不要設定太大,初始1e-4可以參考,另外可以隨著訓練進行不斷縮小學習率
- 給D的網路層增加高斯噪聲,相當於是一種正則
參考
- 《神經網路與深度學習》——邱錫鵬著
- 莫煩python
- GAN的原理 優缺點