# VAE (variational autoencoder) — PyTorch implementation, Jupyter notebook
# version (note: the tqdm import differs for notebooks).
# Author: 阿新 — published 2022-12-06.
# A simple PyTorch port of a variational autoencoder, adapted from a
# TensorFlow reference implementation.
import os
import numpy as np
from PIL import Image
from matplotlib import pyplot as plt
import torch
from torchvision import datasets, transforms
import torch.nn as nn
from time import sleep
from tqdm.notebook import tqdm
class CFG:
    """Hyper-parameters for the VAE training run."""
    batch_size = 10   # samples per mini-batch
    z_dim = 10        # dimensionality of the latent space
    epoch = 1000      # number of training epochs
    lr = 0.0001       # Adam learning rate
# Load MNIST as tensors scaled to [0, 1]; downloads on first run.
mnist_train = datasets.MNIST("mnist-data", train=True, download=True,
                             transform=transforms.ToTensor())
train_loader = torch.utils.data.DataLoader(
    mnist_train, batch_size=CFG.batch_size, shuffle=True)

# Sanity check: print the shape of one batch of images,
# expected to be [batch_size, 1, 28, 28].
for batch in train_loader:
    print(batch[0].shape)
    break
class VAE(nn.Module):
    """Variational autoencoder for flattened 28x28 MNIST images (784 pixels).

    Encoder: 784 -> 128 -> (mean, log_var), each of size ``z_dim``.
    Decoder: z_dim -> 128 -> 784, with a sigmoid applied in ``forward``.
    """

    def __init__(self, z_dim=None):
        """Build the encoder/decoder layers.

        Args:
            z_dim: latent dimensionality; defaults to ``CFG.z_dim`` so the
                original ``VAE()`` call keeps working unchanged.
        """
        super().__init__()
        if z_dim is None:
            z_dim = CFG.z_dim
        self.e1 = nn.Linear(784, 128)
        self.e2 = nn.Linear(128, z_dim)   # latent mean head
        self.e3 = nn.Linear(128, z_dim)   # latent log-variance head
        # BUG FIX: the decoder input size was hard-coded to 10; tie it to
        # z_dim so the model stays consistent when the latent size changes.
        self.fc4 = nn.Linear(z_dim, 128)
        self.fc5 = nn.Linear(128, 784)

    def reparameterize(self, mean, log_var):
        """Sample z ~ N(mean, exp(log_var)) via the reparameterization trick."""
        # randn_like keeps eps on the same device/dtype as the model tensors
        # (the original torch.randn(shape) always sampled on the CPU).
        eps = torch.randn_like(log_var)
        std = torch.exp(0.5 * log_var)  # std = exp(log_var / 2)
        return mean + eps * std

    def encoder(self, inputs):
        """Map a (batch, 784) tensor to (mean, log_var), each (batch, z_dim)."""
        h = torch.relu(self.e1(inputs))
        mean = self.e2(h)
        log_var = self.e3(h)
        return mean, log_var

    def decoder(self, z):
        """Map latent codes (batch, z_dim) to reconstruction logits (batch, 784)."""
        return self.fc5(torch.relu(self.fc4(z)))

    def forward(self, inputs):
        """Encode, sample, decode.

        Returns:
            (x_hat, mean, log_var) where x_hat is in [0, 1] after sigmoid.
        """
        mean, log_var = self.encoder(inputs)
        z = self.reparameterize(mean, log_var)
        x_hat = torch.sigmoid(self.decoder(z))
        return x_hat, mean, log_var
# --- training loop ---
model = VAE()
model.train()

# BUG FIX: the original used CrossEntropyLoss with (target, prediction)
# swapped. CrossEntropyLoss expects integer class labels and is the wrong
# reconstruction loss for per-pixel outputs; since the model already applies
# a sigmoid, use binary cross-entropy on the probabilities, summed over
# pixels and averaged over the batch (the standard Bernoulli VAE ELBO term).
bce_loss = nn.BCELoss(reduction="sum")
optimizer = torch.optim.Adam(model.parameters(), lr=CFG.lr)

for epoch in range(1, CFG.epoch + 1):
    loop = tqdm(train_loader, total=len(train_loader))
    for batch in loop:
        x_ = batch[0].reshape(-1, 784)  # labels (batch[1]) are unused
        optimizer.zero_grad()

        x_rec, mean, log_var = model(x_)
        batch_size = x_.shape[0]
        rec_loss = bce_loss(x_rec, x_) / batch_size

        # KL(q(z|x) || N(0, I)): sum over latent dims, mean over the batch.
        # BUG FIX: the original took the mean over ALL elements and then
        # divided by the batch size again, shrinking the KL term by a factor
        # of batch_size * z_dim relative to the reconstruction loss.
        kl_div = -0.5 * torch.sum(1 + log_var - mean ** 2 - torch.exp(log_var)) / batch_size

        loss = rec_loss + 1.0 * kl_div
        loss.backward()
        optimizer.step()

        loop.set_description(f'Epoch [{epoch}/{CFG.epoch}]')
        loop.set_postfix(loss=loss.item(), Kl_div=kl_div.item(),
                         rec_loss=rec_loss.item())
        sleep(0.05)  # throttle so the notebook progress bar renders smoothly