1. 程式人生 > 實用技巧 >GAN(生成對抗網路)以及keras實現

GAN(生成對抗網路)以及keras實現

由於筆者水平有限,如有錯,歡迎指正。

論文原文:https://arxiv.org/pdf/1406.2661.pdf


0 GAN的思想

GAN,全稱為 Generative Adversarial Nets,直譯為生成式對抗網路,是一種非監督式模型。

GAN的主要靈感來源於博弈論中零和博弈的思想,應用到深度學習神經網路上來說,就是通過生成網路G(Generator)和判別網路D(Discriminator)不斷博弈,進而使G學習到資料的分佈,

GAN網路最強大的地方就是可以幫助我們建立模型,而不像傳統的網路那樣是在已有模型上幫我們更新引數而已。同時,因為GAN網路是一種無監督的學習方式,它的泛化性非常好。


1 GAN模型

1.1網路結構

上圖都描述了GAN的核心網路,在生成網路中,得到假的資料,然後和真的資料一起喂入判別模型,判別模型判斷輸入的樣本是真是假,先訓練識別網路,再訓練生成網路,再訓練識別網路,如此反覆,直到平衡。

1.2具體過程

  1. 生成模型:比作是一個樣本生成器,輸入一個噪聲/樣本,然後把它包裝成一個逼真的樣子,也就是輸出。

    • 生成網路是造樣本,它的目的就是使得自己造樣本的能力盡可能強,強到什麼程度呢,判別網路沒法判斷我是真樣本還是假樣本。

    • 通常這個網路選用最普通的多層隨機網路即可,網路太深容易引起梯度消失或者梯度爆炸。

  2. 判別模型:比作一個二分類器(如同0-1分類器),來判斷輸入的樣本是真是假。(就是輸出值大於0.5還是小於0.5)

    • 判別出來屬於的一張圖它是來自真實樣本集還是假樣本集。若輸入的是真樣本,輸出就接近1,輸出的是假樣本,輸出接近0。

訓練過程中,生成網路G的目標就是儘量生成真實的圖片去欺騙判別網路D。而D的目標就是儘量辨別出G生成的假影象和真實的影象。這樣,G和D構成了一個動態的“博弈過程”,最終的平衡點即納什均衡點.。

納什均衡是指博弈中這樣的局面,對於每個參與者來說,只要其他人不改變策略,他就無法改善自己的狀況。


上圖是是論文中的一張過程圖,判別分佈(藍色,虛線) ,生成資料的實際分佈(黑色,虛線),資料的生成分佈(綠色,實線)

(a) 對於D(判別網路)剛開始訓練,有波動,但基本可以區分實際資料和生成資料;

(b) 隨著訓練的進行,D可以明顯的區分實際資料和生成資料;

(c) 隨著G的更新,綠色的線能夠趨近於黑色的線;

(d) 經過幾步訓練,如果G和D有足夠的能力,他們將達到平衡,辨別器無法區分兩個分佈,即D(x)= 1;

1.3訓練結果

最終,訓練結束後,生成模型 G 恢復了訓練資料的分佈(造出了和真實資料一模一樣的樣本),判別模型再也判別不出來結果,準確率為 50%,約等於亂猜。這是雙方網路都得到利益最大化,不再改變自己的策略,也就是不再更新自己的權重。如果loss值很低,則生成器成功欺騙了識別器(把假資料當成和label一樣也是1了),如果loss很大(label儘管是1,但是識別器還是預測為0,識別器判斷出了真假),說明生成器還需提升)。


3 程式碼實現

  1. 避免使用RELU和pooling層,減少稀疏梯度的可能性,使用leakrelu啟用函式;
  2. 最後一層的啟用函式使用tanh;
  3. 在鑑別器中使用dropout;

3.1 Generative model:

model = Sequential()

model.add(Dense(256, input_dim=self.latent_dim))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(512))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(1024))
model.add(LeakyReLU(alpha=0.2))
model.add(BatchNormalization(momentum=0.8))
model.add(Dense(np.prod(self.img_shape), activation='tanh'))
model.add(Reshape(self.img_shape))
model.summary()

noise = Input(shape=(self.latent_dim,))
img = model(noise)

3.2 Discriminator model:

model = Sequential()     model.add(Flatten(input_shape=self.img_shape))     model.add(Dense(512))     
model.add(LeakyReLU(alpha=0.2))     
model.add(Dense(256))     
model.add(LeakyReLU(alpha=0.2))     
model.add(Dense(1, activation='sigmoid'))     
model.summary()    

img = Input(shape=self.img_shape)     
validity = model(img)

3.3 GAN

discriminator.trainable = False
gan_input = keras.Input(shape=(latent_dim,))
gan_output = discriminator(generator(gan_input))
gan = keras.models.Model(gan_input, gan_output)
gan_optimizer = keras.optimizers.RMSprop(lr=4e-4, clipvalue=1.0, decay=1e-8)
gan.compile(optimizer=gan_optimizer, loss='binary_crossentropy')

4 參考資料

https://www.jianshu.com/p/998cf8e52209

https://zhuanlan.zhihu.com/p/34287744