1. 程式人生 > 實用技巧 >【論文閱讀筆記】《Conditional Generative Adversarial Nets》

【論文閱讀筆記】《Conditional Generative Adversarial Nets》

論文:《Conditional Generative Adversarial Nets》

年份:2014年

引言

原始的GAN過於自由,訓練會很容易失去方向,導致不穩定且效果差。比如說GAN生成MNIST數字的過程,雖然可以生成數字,但生成的結果是隨機的(因為是根據輸入的隨機噪聲生成的圖片),沒有辦法控制模型生成的具體數字。

CGAN就是在原來的GAN模型中加入一些先驗條件,使得GAN變得更加可控制。具體來說,我們可以在生成模型G和判別模型D中同時加入條件約束y來引導資料的生成過程。條件可以是任何補充的資訊,如類標籤等,這樣我們在生成新的樣本的同時,還能確切地控制新樣本的型別。

cGAN結構

cGAN的全程是Conditional Generative Adversarial Networks,即條件對抗生成網路。它為生成器、判別器都額外加入了一個條件y,這個條件實際上是希望生成的標籤。

生成器G必須要生成和條件y匹配的樣本,判別器不僅要判別影象是否真實,還要判別影象和條件y是否匹配。cGAN的輸入輸出為:

  • 生成器G:輸入一個噪聲z,一個條件y,輸出符合該條件的影象G。
  • 判別器D:輸入一張影象x,一個條件y,輸出該影象在該條件下的真實概率D(x|y)

優化目標

在原始的GAN中,優化目標為:

在cGAN中,在其中加入條件y,則優化目標修改成了:

以MNIST為例,生成器G和判別器D的輸入輸出是:

  • G輸入一個噪聲z,一個數字標籤y(y的取值範圍是0~9)。輸出和數字標籤相符合的影象G(z|y)。
  • D輸入一個影象x,一個數字標籤y。輸出影象和數字符合的概率D(x|y)。

顯然,在訓練完成後,向G輸入某個數字標籤和噪聲,可以生成對應數字的影象。

Pytorch程式碼實現

cGAN生成器

定義生成器及前向傳播函式:

class Generator(nn.Module):
  def __init__(self):
    super().__init__()
    self.label_emb = nn.Embedding(10, 10)
    self.model = nn.Sequential(
      nn.Linear(110, 256),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(256, 512),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(512, 1024),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Linear(1024, 784),
      nn.Tanh()
    )
  def forward(self, z, labels):
    z = z.view(z.size(0), 100)
    c = self.label_emb(labels)
    x = torch.cat([z, c], 1)
    out = self.model(x)
    return out.view(x.size(0), 28, 28)

其中,torch.nn.Embedding的函式介紹如下:

nn.Embedding(num_embeddings, embedding_dim)
"""
params:
- num_embeddings - 詞嵌入字典大小,即一個字典裡要有多少個詞。
- embedding_dim - 每個詞嵌入向量的大小。
"""

cGAN判別器

定義判別器及前向傳播函式:

class Discriminator(nn.Module):
  def __init__(self):
    super().__init__()
    self.label_emb = nn.Embedding(10, 10)
    self.model = nn.Sequential(
      nn.Linear(794, 1024),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Dropout(0.4),
      nn.Linear(1024, 512),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Dropout(0.4),
      nn.Linear(512, 256),
      nn.LeakyReLU(0.2, inplace=True),
      nn.Dropout(0.4),
      nn.Linear(256, 1),
      nn.Sigmoid()
    )
    def forward(self, x, labels):
      x = x.view(x.size(0), 784)
      c = self.label_emb(labels)
      x = torch.cat([x, c], 1)
      out = self.model(x)
      return out.squeeze()

cGAN損失函式

定義判別器對真、假影象的損失函式:

# 定義判別器對真影象的損失函式:
real_validity = D(images, labels)
d_loss_real = criterion(real_validity, real_labels)

# 定義判別器對假影象(即由潛在空間點生成的影象)的損失函式
z = torch.randn(batch_size, 100).to(device)
fake_labels = torch.randint(0,10,(batch_size,)).to(device)
fake_images = G(z, fake_labels)
fake_validity = D(fake_images, fake_labels)
d_loss_fake = criterion(fake_validity, torch.zeros(batch_size).to(device))

#CGAN總的損失值
d_loss = d_loss_real + d_loss_fake

cGAN視覺化

利用網格(10×10)的形式顯示指定條件下生成的影象

from torchvision.utils import make_grid
z = torch.randn(100, 100).to(device)
labels = torch.LongTensor([i for i in range(10) for _ in range(10)]).to(device)
images = G(z, labels).unsqueeze(1)
grid = make_grid(images, nrow=10, normalize=True)
fig, ax = plt.subplots(figsize=(10,10))
ax.imshow(grid.permute(1, 2, 0).detach().cpu().numpy(), cmap='binary')
ax.axis('off')

檢視指定標籤資料

視覺化指定單個數字條件下生成的數字:

def generate_digit(generator, digit):
  z = torch.randn(1, 100).to(device)
    
  label = torch.LongTensor([digit]).to(device)
    
  img = generator(z, label).detach().cpu()
    
  img = 0.5 * img + 0.5
    
  return transforms.ToPILImage()(img)

# 呼叫
generate_digit(G, 8)

視覺化損失值

記錄判別器和生成器的損失變化:

writer.add_scalars('scalars', {'g_loss': g_loss, 'd_loss': d_loss}, step)