1. 程式人生 > 程式設計 > PyTorch：實現簡單的 GAN 示例（MNIST 資料集）

PyTorch：實現簡單的 GAN 示例（MNIST 資料集）

我就廢話不多說了，直接上程式碼吧！

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader,sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
# Global plot defaults: figure size, blocky (nearest-neighbour) pixels and a
# grayscale colormap for the digit images.
plt.rcParams.update({
  'figure.figsize': (10.0,8.0),
  'image.interpolation': 'nearest',
  'image.cmap': 'gray',
})
 
def show_images(images): # plotting helper
  """Draw a batch of flattened square images on a near-square grid.

  Args:
    images: array-like whose first axis indexes images. Each image is
      flattened and reshaped to a ceil(sqrt(pixels)) x ceil(sqrt(pixels))
      square, so the per-image pixel count should be a perfect square
      (e.g. 784 for 28x28 images).

  Draws into a new figure and returns None; call ``plt.show()`` to display
  it. An empty batch draws nothing.
  """
  images = np.asarray(images)
  if images.shape[0] == 0:
    # reshape(0, -1) is ambiguous and GridSpec(0, 0) raises: nothing to draw.
    return
  images = images.reshape(images.shape[0],-1)
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))    # grid side, in cells
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))  # image side, in pixels

  plt.figure(figsize=(sqrtn,sqrtn))
  gs = gridspec.GridSpec(sqrtn,sqrtn)
  gs.update(wspace=0.05,hspace=0.05)

  for i,img in enumerate(images):
    ax = plt.subplot(gs[i])
    ax.axis('off')  # also hides tick labels
    ax.set_aspect('equal')
    ax.imshow(img.reshape([sqrtimg,sqrtimg]))
  
def preprocess_img(x):
  """Convert an image to a float tensor scaled to [-1, 1].

  ``ToTensor`` yields values in [0, 1] for 8-bit/PIL images; subtracting 0.5
  and dividing by 0.5 maps that onto [-1, 1], the range of the generator's
  final Tanh.
  """
  tensor = tfs.ToTensor()(x)
  return tensor.sub(0.5).div(0.5)
 
def deprocess_img(x):
  """Undo ``preprocess_img``'s scaling: map values from [-1, 1] back to [0, 1].

  Element-wise, so it accepts tensors, NumPy arrays or plain floats.
  """
  shifted = x + 1.0
  return shifted / 2.0
 
class ChunkSampler(sampler.Sampler): # sampler over a contiguous index range
  """Yield a contiguous run of dataset indices, in order.

  Iteration produces ``start, start + 1, ..., start + num_samples - 1``.

  Args:
    num_samples: how many indices to yield.
    start: first index to yield (default 0).
  """
  def __init__(self,num_samples,start=0):
    self.start = start
    self.num_samples = num_samples

  def __iter__(self):
    first = self.start
    stop = first + self.num_samples
    return iter(range(first, stop))

  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000

NOISE_DIM = 96
batch_size = 128

# Training batches: indices [0, NUM_TRAIN) of the MNIST training split.
train_set = MNIST('E:/data',train=True,transform=preprocess_img)

train_data = DataLoader(train_set,batch_size=batch_size,sampler=ChunkSampler(NUM_TRAIN,0))

# Validation: the following NUM_VAL indices of the same dataset (``train`` is
# left at its default here).
val_set = MNIST('E:/data',transform=preprocess_img)

val_data = DataLoader(val_set,sampler=ChunkSampler(NUM_VAL,NUM_TRAIN))

# Visualize one training batch. DataLoader iterators have no ``.next()`` method
# on Python 3 / current PyTorch, so use the ``next()`` builtin instead.
first_batch = next(iter(train_data))[0]
imgs = deprocess_img(first_batch.view(first_batch.shape[0],784)).numpy().squeeze()
show_images(imgs)
 
# Discriminator network
def discriminator():
  """Build the MLP discriminator.

  Linear(784, 256) -> LeakyReLU(0.2) -> Linear(256, 1): maps each flattened
  784-value input to one raw score (a logit, consumed by ``bce_loss``).
  """
  layers = [
    nn.Linear(784,256),
    nn.LeakyReLU(0.2),
    nn.Linear(256,1),
  ]
  return nn.Sequential(*layers)
  
# Generator network
def generator(noise_dim=NOISE_DIM):
  """Build the MLP generator.

  Linear(noise_dim, 1024) -> ReLU -> Linear(1024, 784) -> Tanh: maps a
  (batch, noise_dim) noise tensor to (batch, 784) outputs in Tanh's range.
  """
  hidden = 1024
  return nn.Sequential(
    nn.Linear(noise_dim,hidden),
    nn.ReLU(inplace=True),
    nn.Linear(hidden,784),
    nn.Tanh(),
  )
  
# Discriminator loss: real samples should score 1 and fake samples 0.
# Generator loss: fake samples should be scored as 1 by the discriminator.

bce_loss = nn.BCEWithLogitsLoss(reduction='mean')  # sigmoid + binary cross-entropy on raw logits
 
def discriminator_loss(logits_real,logits_fake): # discriminator loss
  """Binary cross-entropy loss for the discriminator.

  Real samples are labelled 1 and generated samples 0; the two ``bce_loss``
  terms (each a batch mean) are summed.

  Args:
    logits_real: discriminator logits for a batch of real images.
    logits_fake: discriminator logits for a batch of generated images.

  Returns:
    Scalar loss tensor.
  """
  # ones_like / zeros_like give each target its own input's shape, dtype and
  # device, so the real and fake batches may differ in size and the loss
  # also works for non-CPU / non-float32 logits.
  true_labels = torch.ones_like(logits_real)
  false_labels = torch.zeros_like(logits_fake)
  return bce_loss(logits_real,true_labels) + bce_loss(logits_fake,false_labels)
  
def generator_loss(logits_fake): # generator loss
  """Binary cross-entropy loss for the generator.

  The generator wants its samples classified as real, so every fake logit
  is compared against a target of 1.

  Args:
    logits_fake: discriminator logits for a batch of generated images.

  Returns:
    Scalar loss tensor (batch-mean BCE via ``bce_loss``).
  """
  # ones_like matches the logits' shape, dtype and device.
  true_labels = torch.ones_like(logits_fake)
  return bce_loss(logits_fake,true_labels)
  
# Train with Adam: learning rate 3e-4, beta1 0.5, beta2 0.999 by default.
def get_optimizer(net,lr=3e-4,betas=(0.5,0.999)):
  """Create an Adam optimizer over all parameters of ``net``.

  Args:
    net: module whose ``parameters()`` are optimized.
    lr: learning rate (default 3e-4).
    betas: Adam (beta1, beta2) coefficients (default (0.5, 0.999)).

  Returns:
    A ``torch.optim.Adam`` instance.
  """
  return torch.optim.Adam(net.parameters(),lr=lr,betas=betas)
  
def train_a_gan(D_net,G_net,D_optimizer,G_optimizer,discriminator_loss,generator_loss,show_every=250,noise_size=96,num_epochs=10,train_loader=None):
  """Train a GAN, alternating one discriminator and one generator update per batch.

  Args:
    D_net: discriminator; takes flattened images and returns logits.
    G_net: generator; takes a (batch, noise_size) noise tensor.
    D_optimizer: optimizer stepping ``D_net``'s parameters.
    G_optimizer: optimizer stepping ``G_net``'s parameters.
    discriminator_loss: callable(logits_real, logits_fake) -> scalar loss.
    generator_loss: callable(logits_fake) -> scalar loss.
    show_every: every this many iterations (counted across epochs, starting
      at 0) print both losses and plot the first 16 generated images.
    noise_size: width of the uniform [-1, 1) noise fed to ``G_net``; must
      match the generator's input size.
    num_epochs: number of passes over the training data.
    train_loader: iterable of (images, labels) batches; defaults to the
      module-level ``train_data`` loader.
  """
  loader = train_data if train_loader is None else train_loader
  iter_count = 0
  for _epoch in range(num_epochs):
    for x,_labels in loader:
      bs = x.shape[0]

      # --- Discriminator step ---
      real_data = x.view(bs,-1)  # flatten each image
      logits_real = D_net(real_data)

      sample_noise = (torch.rand(bs,noise_size) - 0.5) / 0.5  # uniform in [-1, 1)
      fake_images = G_net(sample_noise)
      # detach(): the D update must not backprop into G; G's graph is kept
      # intact for the generator step below.
      logits_fake = D_net(fake_images.detach())

      d_total_error = discriminator_loss(logits_real,logits_fake)
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step()

      # --- Generator step ---
      # G has not been updated since fake_images was produced, so re-running
      # it on the same noise would yield the same tensor: reuse it.
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake)
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step()

      if (iter_count % show_every == 0):
        print('Iter: {},D: {:.4},G:{:.4}'.format(iter_count,d_total_error.item(),g_error.item()))
        imgs_numpy = deprocess_img(fake_images.detach().cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()

D_optim = get_optimizer(D)
G_optim = get_optimizer(G)

# train_a_gan needs both loss functions, in signature order
# (discriminator_loss, generator_loss); noise_size must match the generator's
# input width, NOISE_DIM.
train_a_gan(D,G,D_optim,G_optim,discriminator_loss,generator_loss,noise_size=NOISE_DIM)

以上這篇 PyTorch：實現簡單的 GAN 示例（MNIST 資料集）就是小編分享給大家的全部內容了，希望能給大家一個參考，也希望大家多多支援我們。