
Building a Generative Adversarial Network (GAN) with PyTorch


Generative adversarial networks (GANs) were introduced in 2014 and are known for producing strong image-generation results. Since I have recently been learning PyTorch, this post uses it to build a simple GAN in which both the generator and the discriminator are implemented with fully connected layers.
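For background, the original paper (Goodfellow et al., 2014) frames training as a two-player minimax game between the generator $G$ and the discriminator $D$:

$$\min_G \max_D \; \mathbb{E}_{x \sim p_{\text{data}}}[\log D(x)] + \mathbb{E}_{z \sim p_z}[\log(1 - D(G(z)))]$$

The script below optimizes a binary cross-entropy version of this objective, using the common non-saturating trick for the generator: $G$ is trained to make $D$ label its samples as real, rather than to minimize $\log(1 - D(G(z)))$ directly.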

import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt


def preprocess_img(x):
    x = tfs.ToTensor()(x)   # x in (0., 1.)
    return (x - 0.5) / 0.5  # rescale to (-1., 1.)


def deprocess_img(x):       # x in (-1., 1.)
    return (x + 1.0) / 2.0  # back to (0., 1.)


def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1),
    )
    return net
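As a quick sanity check (an illustrative aside, not part of the training script), the discriminator should map a batch of flattened 28×28 images to one real/fake logit per image:

D = discriminator()
x = torch.randn(4, 784)  # a dummy batch of 4 flattened images
print(D(x).shape)        # torch.Size([4, 1]) -- one logit per image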
def generator(noise_dim):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh(),  # outputs in (-1., 1.), matching the preprocessed images
    )
    return net
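Similarly (again just an illustrative check, assuming a noise length of 96 as in the script below), the generator turns a noise vector into a flattened image whose pixel values stay inside (-1, 1):

G = generator(96)
z = (torch.rand(4, 96) - 0.5) / 0.5  # uniform noise in (-1, 1)
out = G(z)
print(out.shape, out.min().item(), out.max().item())  # torch.Size([4, 784]), values within (-1, 1) from Tanh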
def discriminator_loss(logits_real, logits_fake):  # discriminator loss
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1).float()
    false_labels = torch.zeros(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss


def generator_loss(logits_fake):  # generator loss
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float()
    bce_loss = nn.BCEWithLogitsLoss()
    # loss between fake logits and "real" labels; minimizing it pushes the fakes towards real images
    loss = bce_loss(logits_fake, true_labels)
    return loss


# train with Adam, beta1 = 0.5, beta2 = 0.999
def get_optimizer(net, LearningRate):
    optimizer = torch.optim.Adam(net.parameters(), lr=LearningRate, betas=(0.5, 0.999))
    return optimizer


def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss,
                noise_size, num_epochs, num_img):
    f, a = plt.subplots(num_img, num_img, figsize=(num_img, num_img))
    plt.ion()  # turn interactive mode on so the figure updates continuously
    for epoch in range(num_epochs):
        # train_data is the module-level DataLoader created under __main__
        for iteration, (x, _) in enumerate(train_data):
            bs = x.shape[0]
            # train the discriminator
            real_data = x.view(bs, -1)                             # real data, flattened
            logits_real = D_net(real_data)                         # discriminator scores on real images
            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in -1 ~ 1
            fake_images = G_net(rand_noise)                        # generated fake images
            logits_fake = D_net(fake_images)                       # discriminator scores on fakes
            d_total_error = discriminator_loss(logits_real, logits_fake)  # discriminator loss
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                                     # update the discriminator

            # train the generator
            rand_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in -1 ~ 1
            fake_images = G_net(rand_noise)                        # generated fake images
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)              # generator loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()                                     # update the generator

            if iteration % 20 == 0:
                print('Epoch: {:2d} | Iter: {:<4d} | D: {:.4f} | G:{:.4f}'.format(
                    epoch, iteration, d_total_error.data.numpy(), g_error.data.numpy()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                for i in range(num_img ** 2):
                    a[i // num_img][i % num_img].imshow(np.reshape(imgs_numpy[i], (28, 28)), cmap='gray')
                    a[i // num_img][i % num_img].set_xticks(())
                    a[i // num_img][i % num_img].set_yticks(())
                plt.suptitle('epoch: {} iteration: {}'.format(epoch, iteration))
                plt.pause(0.01)
    plt.ioff()
    plt.show()


if __name__ == '__main__':
    EPOCH = 5          # number of epochs
    BATCH_SIZE = 128   # images fed to the network per step
    LR = 5e-4          # learning rate
    NOISE_DIM = 96     # length of the noise vector
    NUM_IMAGE = 4      # grid size of images shown while training

    # fetch the MNIST dataset, downloading it if it is not already present
    train_set = MNIST(root='./data/mnist/', train=True, download=True, transform=preprocess_img)
    # load the data
    train_data = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)

    D = discriminator()
    G = generator(NOISE_DIM)
    D_optim = get_optimizer(D, LR)
    G_optim = get_optimizer(G, LR)
    train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss, NOISE_DIM, EPOCH, NUM_IMAGE)
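Once training finishes, new digits can be sampled from the generator alone. A minimal sketch, reusing G, NOISE_DIM, and deprocess_img from the script above:

with torch.no_grad():
    z = (torch.rand(16, NOISE_DIM) - 0.5) / 0.5  # same noise distribution as in training
    samples = deprocess_img(G(z).numpy())        # map (-1., 1.) back to (0., 1.) for display

fig, axes = plt.subplots(4, 4, figsize=(4, 4))
for i, ax in enumerate(axes.flat):
    ax.imshow(samples[i].reshape(28, 28), cmap='gray')
    ax.set_xticks(())
    ax.set_yticks(())
plt.show()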