程式人生 > 其它 > pytorch(二十六):自動編碼器

pytorch(二十六):自動編碼器

一、自動編碼器

1、AE.py（注意：main.py 以 `from auto_encoder import AE` 匯入此模型，故實際存檔時檔名應為 auto_encoder.py）

import torch
from torch import nn

class AE(nn.Module):
    """Fully connected autoencoder for 28x28 single-channel images.

    The encoder maps a flattened 784-element input to a 20-element code via
    784 -> 256 -> 64 -> 20 linear layers with ReLU activations. The decoder
    mirrors it (20 -> 64 -> 256 -> 784) and ends with a Sigmoid, so every
    reconstructed value lies in [0, 1].
    """

    def __init__(self):
        super().__init__()

        # [b, 784] => [b, 20]
        self.encoder = nn.Sequential(
            nn.Linear(784, 256),
            nn.ReLU(),
            nn.Linear(256, 64),
            nn.ReLU(),
            nn.Linear(64, 20),
            nn.ReLU(),
        )

        # [b, 20] => [b, 784]
        self.decoder = nn.Sequential(
            nn.Linear(20, 64),
            nn.ReLU(),
            nn.Linear(64, 256),
            nn.ReLU(),
            nn.Linear(256, 784),
            nn.Sigmoid(),
        )

    def forward(self, x):
        """Encode and reconstruct a batch of images.

        :param x: tensor of shape [b, 1, 28, 28]; any tensor whose samples
            hold 784 elements each is accepted, since it is flattened first.
        :return: tuple ``(x_hat, None)`` where ``x_hat`` has shape
            [b, 1, 28, 28]. The second element is always None; it appears to
            be a placeholder for an extra output (TODO confirm, e.g. a KL
            term in a VAE variant sharing this interface).
        """
        batchsz = x.shape[0]
        # Flatten. reshape (unlike view) also accepts non-contiguous inputs.
        x = x.reshape(batchsz, 784)
        x = self.encoder(x)
        x = self.decoder(x)
        # Restore the image layout; the decoder output is contiguous.
        x = x.view(batchsz, 1, 28, 28)
        return x, None

2、main.py

import torch
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
from auto_encoder import AE
from torch import nn, optim
import visdom
def main(epochs=1000, batch_size=32):
    """Train the AE on MNIST and show reconstructions in Visdom.

    Downloads MNIST into ./mnist if needed and minimizes the reconstruction
    MSE with Adam (lr=1e-3). After every epoch it prints the loss of the last
    training batch. It then displays one random test batch and its
    reconstruction in the Visdom windows "x" and "x_hat".

    :param epochs: number of passes over the training set (default 1000).
    :param batch_size: mini-batch size for both loaders (default 32).
    """
    transform = transforms.Compose([
        transforms.ToTensor()
    ])
    mnist_train = datasets.MNIST("mnist", True, transform=transform, download=True)
    mnist_train = DataLoader(mnist_train, batch_size=batch_size, shuffle=True)

    mnist_test = datasets.MNIST("mnist", False, transform=transform, download=True)
    # Shuffled so each epoch visualizes a different random test batch.
    mnist_test = DataLoader(mnist_test, batch_size=batch_size, shuffle=True)

    # Sanity check of the batch shape, expected [b, 1, 28, 28].
    x, _ = next(iter(mnist_train))
    print(x.shape)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = AE().to(device)
    criterion = nn.MSELoss()
    optimizer = optim.Adam(model.parameters(), lr=1e-3)
    viz = visdom.Visdom()

    for epoch in range(epochs):
        model.train()
        for x, _ in mnist_train:
            # x: [b, 1, 28, 28]
            x = x.to(device)
            x_hat, _ = model(x)
            loss = criterion(x_hat, x)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
        # Loss of the last training batch of this epoch.
        print(epoch, "loss:", loss.item())

        model.eval()
        x, _ = next(iter(mnist_test))
        with torch.no_grad():
            x_hat, _ = model(x.to(device))
        # Hand CPU tensors to Visdom so this also works when training on GPU.
        viz.images(x, nrow=8, win="x", opts=dict(title="x"))
        viz.images(x_hat.cpu(), nrow=8, win="x_hat", opts=dict(title="x_hat"))


if __name__ == '__main__':
    main()