1. 程式人生 > 其它 >VAE(變分自編碼器的torch實現) —— jupyter實現(注意tqdm模組不同)

VAE(變分自編碼器的torch實現) —— jupyter實現(注意tqdm模組不同)

簡單實現了torch版本的變分自編碼器

參考大佬TensorFlow版本的VAE:膜拜大佬

import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from time import sleep
from tqdm.notebook import tqdm
class CFG:
    """Hyper-parameters for the VAE training run."""
    batch_size = 10  # samples per mini-batch
    z_dim = 10       # dimensionality of the latent code z
    epoch = 1000     # number of passes over the training set
    lr = 0.0001      # Adam learning rate
# Download the MNIST training split (as [0, 1] tensors) and wrap it in a shuffled loader.
mnist_train = datasets.MNIST("mnist-data", train=True, download=True, transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(mnist_train, batch_size=CFG.batch_size, shuffle=True)

# Sanity check: print the shape of a single image batch
# (presumably [batch_size, 1, 28, 28] — standard MNIST layout).
first_batch = next(iter(train_loader))
print(first_batch[0].shape)
class VAE(nn.Module):
    """Variational autoencoder for flattened 28x28 MNIST images (784-d inputs).

    Encoder: 784 -> hidden -> (mean, log_var); decoder: z -> hidden -> 784,
    with a sigmoid so reconstructions lie in [0, 1] like the input pixels.
    """

    def __init__(self, z_dim=None, hidden_dim=128):
        """Build encoder/decoder layers.

        Args:
            z_dim: latent dimension; defaults to CFG.z_dim for backward
                compatibility with the original no-argument constructor.
            hidden_dim: width of the single hidden layer (default 128).
        """
        super().__init__()
        if z_dim is None:
            z_dim = CFG.z_dim
        # Encoder: shared hidden layer, then separate mean / log-variance heads.
        self.e1 = nn.Linear(784, hidden_dim)
        self.e2 = nn.Linear(hidden_dim, z_dim)  # mean head
        self.e3 = nn.Linear(hidden_dim, z_dim)  # log-variance head
        # Decoder input was hard-coded to 10; use z_dim so it always matches
        # the encoder heads even if the latent size changes.
        self.fc4 = nn.Linear(z_dim, hidden_dim)
        self.fc5 = nn.Linear(hidden_dim, 784)

    def reparameterize(self, mean, log_var):
        """Sample z ~ N(mean, exp(log_var)) via the reparameterization trick."""
        std = torch.exp(0.5 * log_var)      # exp(log_var)**0.5, computed stably
        eps = torch.randn_like(std)         # same device/dtype as the model
        return mean + eps * std

    def encoder(self, inputs):
        """Map a (batch, 784) input to (mean, log_var), each (batch, z_dim)."""
        h = torch.relu(self.e1(inputs))
        return self.e2(h), self.e3(h)

    def decoder(self, z):
        """Map a latent code z to 784 reconstruction logits (no activation)."""
        return self.fc5(torch.relu(self.fc4(z)))

    def forward(self, inputs):
        """Return (x_hat, mean, log_var); x_hat is sigmoid-squashed to [0, 1]."""
        mean, log_var = self.encoder(inputs)
        z = self.reparameterize(mean, log_var)
        x_hat = torch.sigmoid(self.decoder(z))
        return x_hat, mean, log_var
model = VAE()
model.train()
# Reconstruction loss: pixel-wise binary cross-entropy. The decoder ends in a
# sigmoid and MNIST pixels lie in [0, 1], so BCE is the appropriate likelihood.
# (The original used CrossEntropyLoss with input/target swapped, which treats
# pixel intensities as class logits — not a reconstruction loss.)
reconstruction_loss = torch.nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

for epoch in range(1, CFG.epoch + 1):
    loop = tqdm(train_loader, total=len(train_loader))
    for batch in loop:
        # Flatten images to (batch, 784); labels are unused in a VAE.
        x = batch[0].reshape(-1, 784)
        optimizer.zero_grad()
        x_rec, mean, log_var = model(x)
        # Sum over pixels, average over the batch.
        rec_loss = reconstruction_loss(x_rec, x) / x.shape[0]
        # KL(N(mean, var) || N(0, I)): sum over latent dims, average over the
        # batch. (The original applied mean over all elements AND divided by
        # the batch size, double-shrinking the KL term.)
        kl_div = -0.5 * torch.sum(log_var + 1 - mean ** 2 - torch.exp(log_var)) / x.shape[0]

        loss = rec_loss + 1.0 * kl_div  # beta = 1.0
        loss.backward()
        optimizer.step()

        loop.set_description(f'Epoch [{epoch}/{CFG.epoch}]')
        loop.set_postfix(loss=loss.item(), Kl_div=kl_div.item(), rec_loss=rec_loss.item())
        # NOTE(review): removed the per-batch sleep(0.05) — it added ~5 minutes
        # of idle time per epoch and tqdm.notebook renders fine without it.