Decoupled Learning for Conditional Adversarial Networks
阿新 • • 發佈:2018-11-06
文章提出裡在已有的ED+GAN的基礎上,添見一個生成網路,即ED//GAN,網路結構如下,
上圖中左邊為傳統的GAN網路,Enc+Dec相當於生成網路,D為判別網路,構造GAN損失函式,以及生成圖片與輸入的重構誤差(L1損失函式,這種網路結構我們熟悉的有pix2pix,cyclegan.
上圖中右邊為本文提出的網路結構,即在ED+GAN的基礎上,新增一個生成網路,相當於有兩個生成網路.兩個生成網路的目的是,兩個生成網路分別可以學習影象的不同特徵,例如,一個生成網路用於生成低頻特徵,另一個用於生成高頻特徵,
最後的結果為兩個生成網路的和.判別網路用於判斷最後生成的圖片,以及目標輸入圖片的真假.損失函式同樣為GAN loss,以及生成網路Enc+Dec網路生成圖片與輸入圖片的重構誤差,也就是希望生成圖片的低頻特徵與輸入圖片儘量相似.
作者提供了github程式碼:https://github.com/ZZUTK/Decoupled-Learning-Conditional-GAN
程式碼包括在pix2pix,CAAE模型上的EN//GAN的結構,下面我以pix2pix的EN//GAN模型為例,分析程式碼.
程式碼中,Enc+Dec,生成網路generator,生成網路G,generator_p,的結構都與pix2pix的生成網路結構相同,
將輸入圖片輸入兩個網路,
self.const_B = self.generator(self.real_A)
self.res_B = self.generator_p(self .real_A)
將兩者相加,得到最後的生成圖片,
self.fake_B = self.const_B + self.res_B
與pix2pix一樣,將生成圖片,目標圖片分別與輸入圖片串聯,輸入判別網路,判別網路結構也與pix2pix判別網路相同,
self.real_AB = tf.concat( [self.real_A, self.real_B],3)
self.fake_AB = tf.concat([self.real_A, self.fake_B],3)
self.D, self.D_logits = self.discriminator(self.real_AB, reuse=False )
self.D_, self.D_logits_ = self.discriminator(self.fake_AB, reuse=True)
判別網路損失函式為GAN loss,
self.d_loss_real = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits, labels=tf.ones_like(self.D)))
self.d_loss_fake = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.zeros_like(self.D_)))
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
生成網路損失函式,重構誤差函式為Enc+Dec輸出與輸入圖片的重構誤差,以及GAN loss,
self.g_loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=self.D_logits_, labels=tf.ones_like(self.D_)))
self.const_loss = tf.reduce_mean(tf.abs(self.real_B - self.const_B))
生成效果對比,