[Deep Learning] Semi-Supervised and Unsupervised Learning: the Variational Auto-Encoder (VAE), with code
阿新 · Published 2018-12-24
Full paper title: "Auto-Encoding Variational Bayes"
Paper link: https://arxiv.org/pdf/1312.6114.pdf
Code:
Keras version: https://github.com/bojone/vae
There are plenty of VAE tutorials online, but few explain the idea clearly, and even fewer pair the explanation with code.
"Talk is cheap, show me the code."
Here is a well-written VAE analysis by another blogger: https://spaces.ac.cn/archives/5253
That article lines up well with the code below, so I see no need to re-derive all the VAE equations here.
First, import the packages and set the hyperparameters:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F

# 2-d latent space, parameter count in same order of magnitude
# as in the original VAE paper (VAE paper has about 3x as many)
latent_dims = 2
num_epochs = 100
batch_size = 128
capacity = 64
learning_rate = 1e-3
variational_beta = 1
use_gpu = True
Load the MNIST dataset:
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST

img_transform = transforms.Compose([
    transforms.ToTensor()
])

train_dataset = MNIST(root='./data/MNIST', download=True, train=True, transform=img_transform)
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

test_dataset = MNIST(root='./data/MNIST', download=True, train=False, transform=img_transform)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=True)
Now define the VAE itself. The overall structure is much like a plain autoencoder, except that the encoder learns the mean and (log-)variance of the latent variable, and the latent code is then sampled from that distribution. Note that the encoder ends in two separate fully connected layers, one for the mean and one for the log-variance.
The latent_sample function implements the reparameterization trick: sample ε from N(0, I) and set z = μ + σ × ε. Since the encoder outputs log σ², the code recovers σ as exp(logvar / 2).
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()
        c = capacity
        self.conv1 = nn.Conv2d(in_channels=1, out_channels=c, kernel_size=4, stride=2, padding=1) # out: c x 14 x 14
        self.conv2 = nn.Conv2d(in_channels=c, out_channels=c*2, kernel_size=4, stride=2, padding=1) # out: c*2 x 7 x 7
        self.fc_mu = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)
        self.fc_logvar = nn.Linear(in_features=c*2*7*7, out_features=latent_dims)

    def forward(self, x):
        x = F.relu(self.conv1(x))
        x = F.relu(self.conv2(x))
        x = x.view(x.size(0), -1) # flatten batch of multi-channel feature maps to a batch of feature vectors
        x_mu = self.fc_mu(x)
        x_logvar = self.fc_logvar(x)
        return x_mu, x_logvar

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()
        c = capacity
        self.fc = nn.Linear(in_features=latent_dims, out_features=c*2*7*7)
        self.conv2 = nn.ConvTranspose2d(in_channels=c*2, out_channels=c, kernel_size=4, stride=2, padding=1)
        self.conv1 = nn.ConvTranspose2d(in_channels=c, out_channels=1, kernel_size=4, stride=2, padding=1)

    def forward(self, x):
        x = self.fc(x)
        x = x.view(x.size(0), capacity*2, 7, 7) # unflatten batch of feature vectors to a batch of multi-channel feature maps
        x = F.relu(self.conv2(x))
        x = torch.sigmoid(self.conv1(x)) # last layer before output is sigmoid, since we are using BCE as reconstruction loss
        return x

class VariationalAutoencoder(nn.Module):
    def __init__(self):
        super(VariationalAutoencoder, self).__init__()
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, x):
        latent_mu, latent_logvar = self.encoder(x)
        latent = self.latent_sample(latent_mu, latent_logvar)
        x_recon = self.decoder(latent)
        return x_recon, latent_mu, latent_logvar

    def latent_sample(self, mu, logvar):
        if self.training:
            # the reparameterization trick: z = mu + eps * std, with eps ~ N(0, I)
            std = logvar.mul(0.5).exp_()
            eps = torch.empty_like(std).normal_()
            return eps.mul(std).add_(mu)
        else:
            return mu
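As a quick sanity check (a minimal sketch of mine, not part of the original post), you can push a dummy batch through the untrained model and confirm the tensor shapes; the names dummy and model are just illustrative:

# hypothetical shape check, assuming the classes above have been defined
dummy = torch.randn(4, 1, 28, 28)           # a fake batch of four MNIST-sized images
model = VariationalAutoencoder()
recon, mu, logvar = model(dummy)
print(recon.shape)   # torch.Size([4, 1, 28, 28])
print(mu.shape)      # torch.Size([4, 2])
print(logvar.shape)  # torch.Size([4, 2])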
Define the loss function. The first term measures how closely the reconstruction matches the input; the second uses the KL divergence to measure how far the latent distribution learned for each image is from the prior.
def vae_loss(recon_x, x, mu, logvar):
    # recon_x is the probability of a multivariate Bernoulli distribution p.
    # -log(p(x)) is then the pixel-wise binary cross-entropy.
    # Averaging or not averaging the binary cross-entropy over all pixels here
    # is a subtle detail with big effect on training, since it changes the weight
    # we need to pick for the other loss term by several orders of magnitude.
    # Not averaging is the direct implementation of the negative log likelihood,
    # but averaging makes the weight of the other loss term independent of the image resolution.
    recon_loss = F.binary_cross_entropy(recon_x.view(-1, 784), x.view(-1, 784), reduction='sum')

    # KL-divergence between the prior distribution over latent vectors
    # (the one we are going to sample from when generating new images)
    # and the distribution estimated by the encoder for the given image.
    kldivergence = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())

    return recon_loss + variational_beta * kldivergence
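For reference, the kldivergence line above is the closed-form KL divergence between the Gaussian posterior N(μ, σ²) produced by the encoder and the standard normal prior, as derived in Appendix B of the paper (here logvar = log σ²):

$$ D_{\mathrm{KL}}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, I)\right) = -\frac{1}{2} \sum_{i=1}^{d} \left(1 + \log \sigma_i^2 - \mu_i^2 - \sigma_i^2\right) $$

Summing over the d latent dimensions (and over the batch, as torch.sum does) gives exactly the expression in the code.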
Instantiate a VAE with the structure above and count how many trainable parameters there are to learn.
vae = VariationalAutoencoder()
device = torch.device("cuda:0" if use_gpu and torch.cuda.is_available() else "cpu")
vae = vae.to(device)
num_params = sum(p.numel() for p in vae.parameters() if p.requires_grad)
print('Number of parameters: %d' % num_params)
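For reference, summing the weights and biases of the layers above by hand, with capacity = 64 and latent_dims = 2 this should come out to 308,357 trainable parameters; worth double-checking against your own printout if you change the hyperparameters.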
Train the VAE:
optimizer = torch.optim.Adam(params=vae.parameters(), lr=learning_rate, weight_decay=1e-5)
# set to training mode
vae.train()
train_loss_avg = []
print('Training ...')
for epoch in range(num_epochs):
    train_loss_avg.append(0)
    num_batches = 0

    for image_batch, _ in train_dataloader:
        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # reconstruction + KL loss
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        # backpropagation
        optimizer.zero_grad()
        loss.backward()

        # one step of the optimizer (using the gradients from backpropagation)
        optimizer.step()

        train_loss_avg[-1] += loss.item()
        num_batches += 1

    train_loss_avg[-1] /= num_batches
    print('Epoch [%d / %d] average loss: %f' % (epoch+1, num_epochs, train_loss_avg[-1]))
Plot the loss curve:
import matplotlib.pyplot as plt
plt.ion()
fig = plt.figure()
plt.plot(train_loss_avg)
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.show()
If you don't want to train from scratch, you can load a pretrained model instead.
filename = 'vae_2d.pth'
# filename = 'vae_10d.pth'
import urllib.request

if not os.path.isdir('./pretrained'):
    os.makedirs('./pretrained')
    print('downloading ...')
    urllib.request.urlretrieve("http://geometry.cs.ucl.ac.uk/creativeai/pretrained/"+filename, "./pretrained/"+filename)

vae.load_state_dict(torch.load('./pretrained/'+filename, map_location=device)) # map_location lets a GPU-trained checkpoint load on a CPU-only machine
print('done')

# this is how the VAE parameters can be saved:
# torch.save(vae.state_dict(), './pretrained/my_vae.pth')
Let's check the results on the test set.
# set to evaluation mode
vae.eval()

test_loss_avg, num_batches = 0, 0
for image_batch, _ in test_dataloader:
    with torch.no_grad():
        image_batch = image_batch.to(device)

        # vae reconstruction
        image_batch_recon, latent_mu, latent_logvar = vae(image_batch)

        # reconstruction + KL loss
        loss = vae_loss(image_batch_recon, image_batch, latent_mu, latent_logvar)

        test_loss_avg += loss.item()
        num_batches += 1

test_loss_avg /= num_batches
print('average loss: %f' % (test_loss_avg))
Visualize the results:
import numpy as np
import matplotlib.pyplot as plt
plt.ion()

import torchvision.utils

vae.eval()

def to_img(x):
    x = x.clamp(0, 1)
    return x

def show_image(img):
    img = to_img(img)
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))

# this function takes the images to reconstruct and the model
# that performs the reconstructions as inputs
def visualise_output(images, model):
    with torch.no_grad():
        images = images.to(device)
        images, _, _ = model(images)
        images = images.cpu()
        images = to_img(images)
        np_imagegrid = torchvision.utils.make_grid(images[1:50], 10, 5).numpy()
        plt.imshow(np.transpose(np_imagegrid, (1, 2, 0)))
        plt.show()
images, labels = next(iter(test_dataloader))
# First visualise the original images
print('Original images')
show_image(torchvision.utils.make_grid(images[1:50],10,5))
plt.show()
# Reconstruct and visualise the images using the vae
print('VAE reconstruction:')
visualise_output(images, vae)
Visualize the 2D latent space:
# load a network that was trained with a 2d latent space
if latent_dims != 2:
    print('Please change the parameters to two latent dimensions.')

with torch.no_grad():
    # create a sample grid in 2d latent space
    latent_x = np.linspace(-1.5, 1.5, 20)
    latent_y = np.linspace(-1.5, 1.5, 20)
    latents = torch.FloatTensor(len(latent_y), len(latent_x), 2)
    for i, lx in enumerate(latent_x):
        for j, ly in enumerate(latent_y):
            latents[j, i, 0] = lx
            latents[j, i, 1] = ly
    latents = latents.view(-1, 2) # flatten grid into a batch

    # reconstruct images from the latent vectors
    latents = latents.to(device)
    image_recon = vae.decoder(latents)
    image_recon = image_recon.cpu()

    fig, ax = plt.subplots(figsize=(10, 10))
    show_image(torchvision.utils.make_grid(image_recon.data[:400], 20, 5))
    plt.show()
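Finally, since the decoder maps any latent vector back to image space, you can also generate entirely new digits by sampling latent vectors from the prior N(0, I). This is a minimal sketch under the definitions above, not part of the original post:

# hypothetical sampling example: decode random draws from the prior
vae.eval()
with torch.no_grad():
    z = torch.randn(64, latent_dims, device=device) # 64 latent vectors sampled from N(0, I)
    samples = vae.decoder(z).cpu()

show_image(torchvision.utils.make_grid(samples, 8, 5))
plt.show()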