pytorch GAN偽造手寫體mnist資料集方式
阿新 • • 發佈:2020-01-10
一,mnist資料集
形如上圖的數字手寫體就是mnist資料集。
二,GAN原理(生成對抗網路)
GAN網路一共由兩部分組成:一個是偽造器(Generator,簡稱G),一個是判別器(Discrimniator,簡稱D)
一開始,G由服從某幾個分佈(如高斯分佈)的噪音組成,生成的圖片不斷送給D判斷是否正確,直到G生成的圖片連D都判斷以為是真的。D每一輪除了看過G生成的假圖片以外,還要見資料集中的真圖片,以前者和後者得到的損失函式值為依據更新D網路中的權值。因此G和D都在不停地更新權值。以下圖為例:
在v1時的G只不過是 一堆噪聲,見過資料集(real images)的D肯定能判斷出G所生成的是假的。當然G也能知道D判斷它是假的這個結果,因此G就會更新權值,到v2的時候,G就能生成更逼真的圖片來讓D判斷,當然在v2時D也是會先看一次真圖片,再去判斷G所生成的圖片。以此類推,不斷迴圈就是GAN的思想。
三,訓練程式碼
import argparse import os import numpy as np import math import torchvision.transforms as transforms from torchvision.utils import save_image from torch.utils.data import DataLoader from torchvision import datasets from torch.autograd import Variable import torch.nn as nn import torch.nn.functional as F import torch os.makedirs("images",exist_ok=True) parser = argparse.ArgumentParser() parser.add_argument("--n_epochs",type=int,default=200,help="number of epochs of training") parser.add_argument("--batch_size",default=64,help="size of the batches") parser.add_argument("--lr",type=float,default=0.0002,help="adam: learning rate") parser.add_argument("--b1",default=0.5,help="adam: decay of first order momentum of gradient") parser.add_argument("--b2",default=0.999,help="adam: decay of first order momentum of gradient") parser.add_argument("--n_cpu",default=8,help="number of cpu threads to use during batch generation") parser.add_argument("--latent_dim",default=100,help="dimensionality of the latent space") parser.add_argument("--img_size",default=28,help="size of each image dimension") parser.add_argument("--channels",default=1,help="number of image channels") parser.add_argument("--sample_interval",default=400,help="interval betwen image samples") opt = parser.parse_args() print(opt) img_shape = (opt.channels,opt.img_size,opt.img_size) # 確定圖片輸入的格式為(1,28,28),由於mnist資料集是灰度圖所以通道為1 cuda = True if torch.cuda.is_available() else False class Generator(nn.Module): def __init__(self): super(Generator,self).__init__() def block(in_feat,out_feat,normalize=True): layers = [nn.Linear(in_feat,out_feat)] if normalize: layers.append(nn.BatchNorm1d(out_feat,0.8)) layers.append(nn.LeakyReLU(0.2,inplace=True)) return layers self.model = nn.Sequential( *block(opt.latent_dim,128,normalize=False),*block(128,256),*block(256,512),*block(512,1024),nn.Linear(1024,int(np.prod(img_shape))),nn.Tanh() ) def forward(self,z): img = self.model(z) img = img.view(img.size(0),*img_shape) return img class Discriminator(nn.Module): def __init__(self): super(Discriminator,self).__init__() self.model = nn.Sequential( nn.Linear(int(np.prod(img_shape)),nn.LeakyReLU(0.2,inplace=True),nn.Linear(512,nn.Linear(256,1),nn.Sigmoid(),) def forward(self,img): img_flat = img.view(img.size(0),-1) validity = self.model(img_flat) return validity # Loss function adversarial_loss = torch.nn.BCELoss() # Initialize generator and discriminator generator = Generator() discriminator = Discriminator() if cuda: generator.cuda() discriminator.cuda() adversarial_loss.cuda() # Configure data loader os.makedirs("../../data/mnist",exist_ok=True) dataloader = torch.utils.data.DataLoader( datasets.MNIST( "../../data/mnist",train=True,download=True,transform=transforms.Compose( [transforms.Resize(opt.img_size),transforms.ToTensor(),transforms.Normalize([0.5],[0.5])] ),),batch_size=opt.batch_size,shuffle=True,) # Optimizers optimizer_G = torch.optim.Adam(generator.parameters(),lr=opt.lr,betas=(opt.b1,opt.b2)) optimizer_D = torch.optim.Adam(discriminator.parameters(),opt.b2)) Tensor = torch.cuda.FloatTensor if cuda else torch.FloatTensor # ---------- # Training # ---------- if __name__ == '__main__': for epoch in range(opt.n_epochs): for i,(imgs,_) in enumerate(dataloader): # print(imgs.shape) # Adversarial ground truths valid = Variable(Tensor(imgs.size(0),1).fill_(1.0),requires_grad=False) # 全1 fake = Variable(Tensor(imgs.size(0),1).fill_(0.0),requires_grad=False) # 全0 # Configure input real_imgs = Variable(imgs.type(Tensor)) # ----------------- # Train Generator # ----------------- optimizer_G.zero_grad() # 清空G網路 上一個batch的梯度 # Sample noise as generator input z = Variable(Tensor(np.random.normal(0,1,(imgs.shape[0],opt.latent_dim)))) # 生成的噪音,均值為0方差為1維度為(64,100)的噪音 # Generate a batch of images gen_imgs = generator(z) # Loss measures generator's ability to fool the discriminator g_loss = adversarial_loss(discriminator(gen_imgs),valid) g_loss.backward() # g_loss用於更新G網路的權值,g_loss於D網路的判斷結果 有關 optimizer_G.step() # --------------------- # Train Discriminator # --------------------- optimizer_D.zero_grad() # 清空D網路 上一個batch的梯度 # Measure discriminator's ability to classify real from generated samples real_loss = adversarial_loss(discriminator(real_imgs),valid) fake_loss = adversarial_loss(discriminator(gen_imgs.detach()),fake) d_loss = (real_loss + fake_loss) / 2 d_loss.backward() # d_loss用於更新D網路的權值 optimizer_D.step() print( "[Epoch %d/%d] [Batch %d/%d] [D loss: %f] [G loss: %f]" % (epoch,opt.n_epochs,i,len(dataloader),d_loss.item(),g_loss.item()) ) batches_done = epoch * len(dataloader) + i if batches_done % opt.sample_interval == 0: save_image(gen_imgs.data[:25],"images/%d.png" % batches_done,nrow=5,normalize=True) # 儲存一個batchsize中的25張 if (epoch+1) %2 ==0: print('save..') torch.save(generator,'g%d.pth' % epoch) torch.save(discriminator,'d%d.pth' % epoch)
執行結果:
一開始時,G生成的全是雜音:
然後逐漸呈現數字的雛形:
最後一次生成的結果:
四,測試程式碼:
匯入最後儲存生成器的模型:
from gan import Generator,Discriminator import torch import matplotlib.pyplot as plt from torch.autograd import Variable import numpy as np from torchvision.utils import save_image device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') Tensor = torch.cuda.FloatTensor g = torch.load('g199.pth') #匯入生成器Generator模型 #d = torch.load('d.pth') g = g.to(device) #d = d.to(device) z = Variable(Tensor(np.random.normal(0,(64,100)))) #輸入的噪音 gen_imgs =g(z) #生產圖片 save_image(gen_imgs.data[:25],"images.png",normalize=True)
生成結果:
以上這篇pytorch GAN偽造手寫體mnist資料集方式就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。