基於MNIST的GANs實現【Pytorch】
阿新 • • 發佈:2018-11-15
簡述
這份程式碼其實是根據我之前寫的兩份程式碼修改而來的（之前已有非常詳細的解釋，可以去看看）。
同時，也結合了我之前寫的 DCGANs 程式碼來實現。
只在 MNIST 中選取特定數字進行訓練的做法，參考自下面這篇文章。
之前的程式碼中都有非常詳細的解釋，這裡只是在其基礎上做了一點點改進，因此不再給出特別詳細的解釋，但程式碼中仍然保留了註釋。
圖形演變過程
程式碼
import torch
import torch.nn as nn
import torchvision
import torch.utils.data as Data
import matplotlib.pyplot as plt
import os
import shutil
import imageio
PNGFILE = './png/'  # scratch folder for per-step generator snapshots

# Every run starts from an empty snapshot folder: remove whatever a
# previous run left behind, then (re)create the directory.
if os.path.exists(PNGFILE):
    shutil.rmtree(PNGFILE)
os.mkdir(PNGFILE)
# Hyper Parameters
BATCH_SIZE = 64
LR_G = 0.0001  # learning rate for the generator
LR_D = 0.0001  # learning rate for the discriminator
N_IDEAS = 100  # length of the random noise vector fed to the generator
target_num = 0  # MNIST digit to train on (passed to myMNIST as targetNum)
EPOCH = 10  # number of passes over the whole training set
# True = download the MNIST files only if they are not already under root;
# torchvision skips the download when the files exist. With False, a
# missing dataset raises instead of being fetched.
DOWNLOAD_MNIST = True
ART_COMPONENTS = 28 * 28  # pixels in one flattened MNIST image
# MNIST handwritten digits, optionally restricted to a single digit.
class myMNIST(torchvision.datasets.MNIST):
    """``torchvision.datasets.MNIST`` that can keep only one digit class.

    Args:
        root, train, transform, target_transform, download: passed
            unchanged to ``torchvision.datasets.MNIST``.
        targetNum: if not ``None``, keep only the samples whose label equals
            ``targetNum``, truncated to a whole multiple of ``BATCH_SIZE`` so
            a DataLoader with ``batch_size=BATCH_SIZE`` yields only full
            batches. ``None`` keeps the split unchanged.
    """

    def __init__(self, root, train=True, transform=None, target_transform=None, download=False, targetNum=None):
        super(myMNIST, self).__init__(
            root,
            train=train,
            transform=transform,
            target_transform=target_transform,
            download=download)
        if targetNum is not None:
            images_attr, labels_attr = self._sample_attrs()
            labels = getattr(self, labels_attr)
            keep = labels == targetNum
            images = getattr(self, images_attr)[keep]
            labels = labels[keep]
            # Round down to a multiple of BATCH_SIZE (drop the ragged tail).
            usable = len(labels) // BATCH_SIZE * BATCH_SIZE
            setattr(self, images_attr, images[:usable])
            setattr(self, labels_attr, labels[:usable])

    def _sample_attrs(self):
        """Return the names of the (images, labels) attributes for this split.

        Newer torchvision keeps both splits in ``data``/``targets`` and exposes
        ``train_data``/``train_labels`` only as read-only deprecated aliases,
        so assigning to those aliases fails. Older releases (no ``targets``
        attribute) store each split in ``train_*``/``test_*`` attributes.
        """
        if 'targets' in vars(self):
            return 'data', 'targets'
        if self.train:
            return 'train_data', 'train_labels'
        return 'test_data', 'test_labels'

    def __len__(self):
        """Return the number of samples actually held for the selected split."""
        return len(getattr(self, self._sample_attrs()[0]))
train_data = myMNIST(
    root='./mnist/',  # where the dataset is stored / loaded from
    train=True,  # use the training split
    # ToTensor converts a PIL.Image / numpy.ndarray into a torch.FloatTensor
    # (C x H x W) with values scaled to the [0.0, 1.0] range.
    transform=torchvision.transforms.ToTensor(),
    download=DOWNLOAD_MNIST,  # True: fetch the files only if they are missing
    targetNum=target_num,  # keep only this digit
)
print(len(train_data))

# Serve the training set in shuffled batches of BATCH_SIZE images.
# drop_last=True guarantees every batch is full even when the dataset size is
# not a multiple of BATCH_SIZE (e.g. targetNum=None leaves 60000 samples).
train_loader = Data.DataLoader(
    dataset=train_data,
    batch_size=BATCH_SIZE,
    shuffle=True,
    drop_last=True,
)
# Generator: maps a noise vector of length N_IDEAS to a flattened
# ART_COMPONENTS-pixel image.
G = nn.Sequential(
    nn.Linear(N_IDEAS, 128),
    nn.ReLU(),
    nn.Linear(128, ART_COMPONENTS),
    # Sigmoid keeps generated pixels in (0, 1), the same range ToTensor gives
    # real images. A ReLU here produced unbounded values above 1, which the
    # discriminator could use to tell fakes apart trivially.
    nn.Sigmoid(),
)

# Discriminator: maps a flattened image to the probability that it is real.
D = nn.Sequential(
    nn.Linear(ART_COMPONENTS, 128),
    nn.ReLU(),
    nn.Linear(128, 1),
    nn.Sigmoid(),
)

# Optimizers; the losses themselves are computed inside the training loop.
optimD = torch.optim.Adam(D.parameters(), lr=LR_D)
optimG = torch.optim.Adam(G.parameters(), lr=LR_G)

# Constant real (1) / fake (0) label vectors of length BATCH_SIZE.
# NOTE(review): not referenced by the training loop in this file.
label_Real = torch.ones(BATCH_SIZE)
label_Fake = torch.zeros(BATCH_SIZE)
filePath = []  # snapshot PNG paths, in the order they were written
for epoch in range(EPOCH):
    for step, (images, _) in enumerate(train_loader):
        # Size everything from the actual batch so a short batch cannot break
        # the reshape or mismatch the number of fake samples.
        batch_size = images.size(0)
        real_images = images.reshape(batch_size, -1)
        G_ideas = torch.randn((batch_size, N_IDEAS))
        G_paintings = G(G_ideas)

        # ---- Discriminator step ----
        # The fake batch is detached so D's update does not backprop into G.
        prob_artist0 = D(real_images)  # D tries to push this toward 1
        prob_artist1 = D(G_paintings.detach())  # ... and this toward 0
        # D_loss = -mean(log D(x) + log(1 - D(G(z)))). BCE computes the same
        # value but clamps log() at -100 instead of returning -inf/NaN when
        # the sigmoid saturates.
        D_loss = (
            nn.functional.binary_cross_entropy(prob_artist0, torch.ones_like(prob_artist0))
            + nn.functional.binary_cross_entropy(prob_artist1, torch.zeros_like(prob_artist1))
        )
        optimD.zero_grad()
        D_loss.backward()
        optimD.step()

        # ---- Generator step ----
        # Run D again after its update. Backpropagating through the old graph
        # would read weights optimD.step() modified in place, which modern
        # PyTorch rejects with a RuntimeError.
        prob_artist1 = D(G_paintings)
        # G_loss = mean(log(1 - D(G(z)))) (same objective as before); G minimizes it.
        G_loss = -nn.functional.binary_cross_entropy(prob_artist1, torch.zeros_like(prob_artist1))
        optimG.zero_grad()
        G_loss.backward()
        optimG.step()

        if step % 20 == 0:
            # Save the first generated image of this batch as a snapshot.
            plt.cla()
            picture = G_paintings[0].detach().reshape(28, 28).numpy()
            plt.imshow(picture, cmap=plt.cm.gray_r)
            png_path = PNGFILE + '%d-%d.png' % (epoch, step)
            plt.savefig(png_path)
            filePath.append(png_path)
# Stitch the snapshots, in training order, into an animated GIF. All frames
# are read into memory first, so the snapshot folder can be deleted before
# the GIF is written.
generated_images = [imageio.imread(snapshot) for snapshot in filePath]
shutil.rmtree(PNGFILE)
imageio.mimsave('gan-mnist.gif', generated_images, 'GIF', duration=0.1)