1. 程式人生 > 實用技巧 >GAN——生成式對抗網路

GAN——生成式對抗網路

GAN的定義

  GAN是一個評估和學習生成模型的框架。生成模型的目標是學習到輸入樣本的分佈,用來生成樣本。GAN和傳統的生成模型不同,使用兩個內建模型以“對抗”的方式來使學習分佈不斷接近輸入樣本分佈。兩個模型一個是生成模型(Generative model),用來生成樣本;另一個是判別模型(Discriminative model),產生判斷樣本是真實而不是來自生成模型的概率。生成模型並不直接學習輸入樣本的分佈,而是通過“欺騙”判別模型的方式提高輸入分佈的逼近程度;判別模型則是使用生成樣本和真實樣本來提高判別準確率。

  對於生成模型$G$和判別模型$D$,GAN的優化式的如下:

$\min\limits_{G}\max\limits_{D} V(D,G)$

$ V(D,G) = E_{x\sim p_{data}}[\log_{}D(x)] + E_{z\sim p_z}[\log_{}(1-D(G(z)))]$

  其中$p_{data}$是樣本的真實分佈。比如對於某個解析度的圖片來說,這個分佈基於這個解析度上的所有圖片。注意!即使是亂碼圖片,它也是有概率密度的,只不過很小很小而已。$p_z$是隨機數$z$的分佈,通常用高斯分佈(文章用的是均勻分佈,這是最早的文章);$G(z)$就是生成器基於這個隨機數生成的樣本。$D(x)$是判別器判斷樣本$x$為真實樣本的概率。

  使用梯度下降法進行優化的過程如下:

  每次分別隨機拿到$m$個真實和生成樣本用來對函式($\theta_d$、$\theta_g$分別包含在$D$和$G$中)

$\displaystyle f(\theta_d) = \frac{1}{m}\sum\limits_{i=1}^{m}[\log_{}D(x^{(i)})+\log_{}(1-D(G(z^{(i)})))]$

  梯度上升,也就是優化判別模型;再生成$m$個樣本用來對函式

$\displaystyle g(\theta_g)= \frac{1}{m}\sum\limits_{i=1}^{m}[\log_{}(1-D(G(z^{(i)})))]$

  梯度下降也就是優化生成模型。最終二者都達到最優。

  以下是擬合的過程圖:

  黑點線是樣本$x$的真實分佈,綠線是樣本$x$的生成模型分佈,藍虛線是判別模型判斷$x$屬於真實的概率,下方的$z$是均勻分佈隨機數$z$到生成樣本$x$的對映。

  a圖是初始化時,判別模型$D$和生成模型$G$都很差。

  b圖是取樣本來更新$D$,$D$在此刻變為最優。也就是說,在當前的$G$下,對於每個$x$,都能正確得出它是真實樣本的概率:

$\displaystyle D(x) = \frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$,

  證明在後面,不過想想也是這麼一回事。比如看綠線和黑點線中間的交叉點,此時$x$的真實概率為0.5。

  c圖是更新$G$,$G$在此刻$D$的基礎上變得不錯了。

  d圖是一直迭代到最後,$G$和真實分佈一模一樣,而$D$的判斷概率全是0.5。但是,一模一樣也不是很好。因為樣本集總是有限的,並不能完全契合樣本全體的分佈,所以如果生成分佈和樣本集分佈一模一樣的話可能會過擬合。

全域性最優

  對任意給定的$G$,最優的$D$對每個樣本$x$,都有:

$D_G^*(x) = \displaystyle\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}$

  這是因為最優的$D$最大化關於$\theta_d$的函式:

$\displaystyle V(G,D) = \int_x p_{data}(x)\log_{}(D(x)) + p_g(x)\log_{}(1-D(x))dx$  

  也就是對於每個$x$,這個積分內部函式都取最大值。對於函式

$h(y) = a\log_{}(y)+b\log_{}(1-y),a\ge 0,b\ge 0$

  在$0< y < y^*$時,$h'(y)$大於零;$y^*< y < 1$,$h'(y)$小於零。所以$h(y)$在

$\displaystyle y^*=\frac{a}{a+b}$

  時最大。因此得證。

  假如$G$訓練到了最優,也就是輸出分佈與輸入樣本分佈相同,即$p_{data}=p_g$,而$D$也最優時,有:

$\displaystyle V(D,G) = E_{x\sim p_{data}}\left[\log_{}\frac{p_{data}(x)}{p_{data}(x)+p_g(x)}\right] + E_{x\sim p_g}\left[\log_{}\frac{p_{g}(x)}{p_{data}(x)+p_g(x)}\right]$

$\displaystyle= E_{x\sim p_{data}}\left[\log_{}\frac{1}{2}\right] + E_{x\sim p_g}\left[\log_{}\frac{1}{2}\right]=-\log_{}4$

CGAN

  CGAN(Conditional GAN)是GAN的一種基本變通,這裡記錄一下。相對於基本GAN的生成器和判別器,輸入分別只有隨機抽樣和樣本,而CGAN則可以附帶條件。CGAN生成器的輸入除了隨機抽樣外,還可以附加樣本的一些特徵,從而可以更加精確地生成我們期望的生成樣本。判別器則是輸入樣本和對應的特徵,聯合這兩者進行判斷樣本的“真實性”。

  比如用CGAN訓練MNIST時,我們想要讓生成器能生成我們期望的數字。生成器的輸入就是隨機抽樣+對應數字的one-hot編碼,而判別器的輸入就是生成的樣本或真實樣本+對應數字的one-hot編碼。所以CGAN的優化函式就在GAN的基礎上改改:

$\max\limits_G\min\limits_D V(D,G) = E_{x\sim p_{data}}[\log_{}D(x|y)] + E_{z\sim p_z}[\log_{}(1-D(G(z|y)|y))]$

  其中$y$是$x$的標籤。上面$D$中表示的好像是條件概率,我覺得也可以直接理解為聯合概率。生成器和判別器只需將它們的兩個輸入concatenate,後面的層就和GAN類似了。另外,上式沒有對樣本和標籤不匹配的情況進行限制,論文中也沒有寫。這樣的話,模型就可能生成比較真實但與標籤不符的樣本。所以訓練時判別器還應該懲罰真實但標籤錯誤的輸入。

  下面用MNIST訓練CGAN來生成數字,模型結構是用CGAN論文中的。我原本是想用卷積網路來搭建,然而迭代了幾萬次都生成不出有點像數字的圖,最終放棄。而仿照論文用全連線層搭建的模型,雖然也不是特別“真”,至少比我原來的模型效果好多了。下面是生成的數字圖:

  一共迭代了1100次,每次迭代使用100個樣本對生成器和判別器進行訓練。隨著迭代次數的增加,生成圖片的效果逐漸變好,又逐漸崩壞,然後又逐漸變好,如此反覆迴圈,所以要把握迭代停止的時機。理論上來講,如果一直迭代下去,最終是會平穩下來的。但是我迭代到幾千次甚至上萬次,生成的圖片效果依舊沒有變得很好,具體原因不清楚,還有待發掘。

  以下是訓練程式碼:

#%%生成器
from keras import layers,Input,Model,utils,activations
import numpy as np

sample_num = 200
Input_sampling = Input(shape=[sample_num])
Input_label = Input(shape=[10])
   
x1 = layers.Dense(sample_num,activation='relu')(Input_sampling)
x2 = layers.Dense(1000,activation='relu')(Input_label)
x = layers.concatenate([x1,x2])
x = layers.Dropout(0.5)(x)
x = layers.Dense(28*28,activation='sigmoid')(x) 
x = layers.Reshape([28,28,1])(x)

generator = Model([Input_label,Input_sampling],x) 
generator.summary()
utils.plot_model(generator)
#%%判別器
Input_img = Input(shape=[28,28,1])

x1 = layers.Reshape([28*28])(Input_img)
x1 = layers.MaxoutDense(240,5)(x1)
x2 = layers.MaxoutDense(50,5)(Input_label)
x = layers.concatenate([x1,x2])
x = layers.MaxoutDense(240,4)(x)
x = layers.Dropout(0.5)(x)
x = layers.Dense(1,activation='sigmoid')(x)

discriminator = Model([Input_label,Input_img],x) 
discriminator.summary()
utils.plot_model(discriminator)
#%%合併模型GAN
x = generator([Input_label,Input_sampling])
x = discriminator([Input_label,x])
gan = Model([Input_label,Input_sampling],x) 
#%%資料預處理
from keras.datasets import mnist 
import numpy as np 
import matplotlib.pyplot as plt
(train_data,train_labels),(test_data,test_labels) = mnist.load_data() 
def label_to_one_hot(labels):
  l = np.zeros([len(labels),10])
  for i in range(len(labels)):
    l[i,labels[i]]=1
  return l
train_data = train_data[:,:,:,np.newaxis].astype('float')/255
test_data = test_data[:,:,:,np.newaxis].astype('float')/255
train_labels = label_to_one_hot(train_labels)
test_labels = label_to_one_hot(test_labels)   
plt.imshow(train_data[0,:,:,0])    
#%%編譯模型
from tensorflow.keras import optimizers,losses
import matplotlib.pyplot as plt

generator.trainable = True
discriminator.trainable = False 
gan_optimizer = optimizers.Adam()
gan.compile(
  optimizer=gan_optimizer,
  loss='binary_crossentropy') 
discriminator.trainable = True
d_optimizer = optimizers.Adam()
discriminator.compile(
  optimizer=d_optimizer,
  loss='binary_crossentropy')    
#%%訓練
def get_samples():
  return np.random.random([batch_size,sample_num])*2-1
def train_generator(batch_size,if_show_loss): 
  samples = get_samples()
  labels = np.zeros([batch_size,10]) 
  judges = np.ones(batch_size) - np.abs(np.random.normal(scale=0.05,loc = 0,size = batch_size)) 
  for i in labels:
    i[np.random.randint(10)] = 1. 
  gan.fit([labels,samples],judges,verbose=if_show_loss) 
def train_discriminator(data,labels_true_right,batch_size,if_show_loss): 
  #生成器生成影象
  samples = get_samples()
  labels_fake = np.zeros([batch_size,10]) 
  for i in labels_fake:
    i[np.random.randint(10)] = 1. 
  fake_imgs = generator.predict([labels_fake,samples])  
  #獲取錯誤標籤真影象
  s = np.linspace(0,9,10).astype('int')
  lebals_true_wrong = np.zeros_like(labels_true_right)
  for i in range(batch_size): 
    p = np.ones(10)/9  
    p[np.argmax(labels_true_right[i])] = 0 
    lebals_true_wrong[i,np.random.choice(s,1, p=p)] = 1
  #將輸入拼接
  in_imgs = np.concatenate([fake_imgs,data,data],axis = 0)
  in_labels = np.concatenate(
    [labels_fake,lebals_true_wrong,labels_true_right],
    axis = 0) 
  judges_wrong = np.zeros(batch_size*2) + np.random.normal(scale=0.05,loc = 0,size = batch_size*2) 
  judges_right = np.ones(batch_size) - np.random.normal(scale=0.05,loc = 0,size = batch_size) 
  train_judges = np.concatenate([judges_wrong,judges_right],axis=0)
  
  discriminator.fit([in_labels,in_imgs],train_judges,verbose=if_show_loss)
def save_img_and_model(num,i): 
  label = np.zeros([1,10])
  label[0,num] = 1
  img = generator.predict([label,get_samples()])
  plt.imshow( img[0,:,:,0],cmap='bone')
  plt.show( )

  generator.save('generator.h5')
  discriminator.save('discriminator.h5')

epochs = 10000
batch_size = 500
train_size = 20000
now_train = 0
for i in range(epochs):
  print(i)
  if_show_loss = False
  if i % 20 == 0:
    if_show_loss = True 
    save_img_and_model(np.random.randint(10),i)
  train_generator(batch_size,if_show_loss)
  train_discriminator(
    train_data[now_train:now_train+batch_size],
    train_labels[now_train:now_train+batch_size],
    batch_size,if_show_loss)
  now_train = (now_train + batch_size)%train_size 

參考文獻

  Generative Adversarial Networks

  ConditionalGenerative Adversarial Nets