1. 程式人生 > >Decoupled Learning for Conditional Adversarial Networks

Decoupled Learning for Conditional Adversarial Networks

文章提出裡在已有的ED+GAN的基礎上,添見一個生成網路,即ED//GAN,網路結構如下,

這裡寫圖片描述

上圖中左邊為傳統的GAN網路,Enc+Dec相當於生成網路,D為判別網路,構造GAN損失函式,以及生成圖片與輸入的重構誤差(L1損失函式,這種網路結構我們熟悉的有pix2pix,cyclegan.

上圖中右邊為本文提出的網路結構,即在ED+GAN的基礎上,新增一個生成網路,相當於有兩個生成網路.兩個生成網路的目的是,兩個生成網路分別可以學習影象的不同特徵,例如,一個生成網路用於生成低頻特徵,另一個用於生成高頻特徵,

這裡寫圖片描述

最後的結果為兩個生成網路的和.判別網路用於判斷最後生成的圖片,以及目標輸入圖片的真假.損失函式同樣為GAN loss,以及生成網路Enc+Dec網路生成圖片與輸入圖片的重構誤差,也就是希望生成圖片的低頻特徵與輸入圖片儘量相似.

這裡寫圖片描述

作者提供了github程式碼:https://github.com/ZZUTK/Decoupled-Learning-Conditional-GAN

程式碼包括在pix2pix,CAAE模型上的EN//GAN的結構,下面我以pix2pix的EN//GAN模型為例,分析程式碼.

程式碼中,Enc+Dec,生成網路generator,生成網路G,generator_p,的結構都與pix2pix的生成網路結構相同,

將輸入圖片輸入兩個網路,

self.const_B = self.generator(self.real_A)
self.res_B = self.generator_p(self
.real_A)

將兩者相加,得到最後的生成圖片,

self.fake_B = self.const_B + self.res_B

與pix2pix一樣,將生成圖片,目標圖片分別與輸入圖片串聯,輸入判別網路,判別網路結構也與pix2pix判別網路相同,

self.real_AB = tf.concat( [self.real_A, self.real_B],3)
self.fake_AB = tf.concat([self.real_A, self.fake_B],3)
self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False
) self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)

判別網路損失函式為GAN loss,

self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))

生成網路損失函式,重構誤差函式為Enc+Dec輸出與輸入圖片的重構誤差,以及GAN loss,

self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
self.const_loss = tf.reduce_mean(tf.abs(self.real_B - self.const_B))

生成效果對比,

這裡寫圖片描述