PyTorch 實現變分自動編碼器(VAE)
阿新 • • 發佈:2018-11-09
本來以為自動編碼器是很簡單的東西,但看了好多資料後仍然不太懂它的原理。先把程式碼記錄下來,有時間再好好研究。
這個例子以 MNIST 資料集的影像生成為例。
# -*- coding: utf-8 -*-
"""
Variational autoencoder (VAE) trained on MNIST with PyTorch.

Created on Fri Oct 12 11:42:19 2018

@author: www
"""
import os

import torch
from torch.autograd import Variable  # noqa: F401  (no longer used; import kept for compatibility)
import torch.nn.functional as F
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

im_tfs = tfs.Compose([
    tfs.ToTensor(),
    # MNIST images have a single channel, so mean and std each take exactly
    # one value (three values raise a shape error in current torchvision).
    tfs.Normalize([0.5], [0.5])  # standardize
])

# Raw string: '\d' is not a valid escape sequence in a normal string literal.
train_set = MNIST(r'E:\data', transform=im_tfs)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)


class VAE(nn.Module):
    """Fully connected VAE: 784 -> 400 -> (20, 20) -> 400 -> 784."""

    def __init__(self):
        super().__init__()

        self.fc1 = nn.Linear(784, 400)
        self.fc21 = nn.Linear(400, 20)  # latent mean
        self.fc22 = nn.Linear(400, 20)  # latent log-variance
        self.fc3 = nn.Linear(20, 400)
        self.fc4 = nn.Linear(400, 784)

    def encode(self, x):
        """Return ``(mu, logvar)`` of the latent distribution for ``x``."""
        h1 = F.relu(self.fc1(x))
        return self.fc21(h1), self.fc22(h1)

    def reparametrize(self, mu, logvar):
        """Sample ``z = mu + eps * std`` with ``eps ~ N(0, I)``.

        The noise is created with ``randn_like`` so it always shares the
        device and dtype of ``std``, instead of being moved to CUDA merely
        because a GPU happens to exist on the machine.
        """
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def decode(self, z):
        """Map a latent code ``z`` to a flattened image; tanh keeps it in [-1, 1]."""
        h3 = F.relu(self.fc3(z))
        return torch.tanh(self.fc4(h3))

    def forward(self, x):
        """Return ``(reconstruction, mu, logvar)`` for the input batch ``x``."""
        mu, logvar = self.encode(x)  # encode
        z = self.reparametrize(mu, logvar)  # reparameterized sample
        return self.decode(z), mu, logvar  # decode, also return mean / log-variance


net = VAE().to(device)  # instantiate the network

# Sanity check: pass a single image through the untrained network.
x, _ = train_set[0]
x = x.view(x.shape[0], -1).to(device)
_, mu, logvar = net(x)
print(mu)

# As shown above, the network outputs the latent mean and log-variance for an
# input; these values are still untrained at this point.
# Training starts below.

# Summed (not averaged) squared error; the caller divides by the batch size.
reconstruction_function = nn.MSELoss(reduction='sum')


def loss_function(recon_x, x, mu, logvar):
    """Return the summed reconstruction error plus the KL divergence.

    recon_x: generating images
    x: origin images
    mu: latent mean
    logvar: latent log variance
    """
    MSE = reconstruction_function(recon_x, x)
    # KLD = -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2)
    KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp())
    return MSE + KLD


optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def to_img(x):
    """Convert flattened network output back to image tensors.

    Shifts and scales values by ``0.5 * (x + 1)``, clamps them to [0, 1],
    and reshapes to ``(N, 1, 28, 28)``.
    """
    x = 0.5 * (x + 1.)
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x


for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1).to(device)
        recon_im, mu, logvar = net(im)
        loss = loss_function(recon_im, im, mu, logvar) / im.shape[0]  # average the loss over the batch
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Every 20 epochs, report the last batch's loss and save its reconstructions.
    if (e + 1) % 20 == 0:
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        save = to_img(recon_im.detach().cpu())
        os.makedirs('./vae_img', exist_ok=True)
        save_image(save, './vae_img/image_{}.png'.format(e + 1))