PyTorch 自動編碼器
阿新 • • 發佈:2018-11-09
這裡主要使用自動編碼器實現生成資料,以MNIST資料為例。
# -*- coding: utf-8 -*-
"""
Created on Thu Oct 11 20:34:33 2018

@author: www

Train a simple fully-connected autoencoder on MNIST and use its decoder
to generate images from a hand-picked 3-D code. A convolutional variant
(`conv_autoencoder`) is defined at the bottom as a better alternative.
"""
import os

import torch
from torch import nn
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms as tfs
from torchvision.utils import save_image
import matplotlib.pyplot as plt

# Data preprocessing and loader construction.
# NOTE: MNIST images are single-channel, so Normalize must receive exactly
# one mean/std value. The original 3-channel stats ([0.5]*3) fail at runtime
# on a (1, 28, 28) tensor.
im_tfs = tfs.Compose([
    tfs.ToTensor(),
    tfs.Normalize([0.5], [0.5])  # map pixel values from [0, 1] to [-1, 1]
])

# download=True fetches the dataset when it is not already under E:/data.
train_set = MNIST('E:/data', transform=im_tfs, download=True)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)


class autoencoder(nn.Module):
    """Fully-connected autoencoder compressing 28x28 images to a 3-D code."""

    def __init__(self):
        super(autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Linear(28 * 28, 128),
            nn.ReLU(True),
            nn.Linear(128, 64),
            nn.ReLU(True),
            nn.Linear(64, 12),
            nn.ReLU(True),
            nn.Linear(12, 3)  # 3-D code, convenient for visualization
        )
        self.decoder = nn.Sequential(
            nn.Linear(3, 12),
            nn.ReLU(True),
            nn.Linear(12, 64),
            nn.ReLU(True),
            nn.Linear(64, 128),
            nn.ReLU(True),
            nn.Linear(128, 28 * 28),
            nn.Tanh()  # output in [-1, 1], matching the normalized input
        )

    def forward(self, x):
        """Return (code, reconstruction) for a flattened (batch, 784) input."""
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode


net = autoencoder()
# Sum-reduced MSE; `size_average=False` is deprecated (and removed in modern
# PyTorch) — reduction='sum' is the exact equivalent.
criterion = nn.MSELoss(reduction='sum')
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)


def to_img(x):
    """Convert flat [-1, 1] network outputs back to (N, 1, 28, 28) images in [0, 1]."""
    x = 0.5 * (x + 1.)  # undo the Normalize: [-1, 1] -> [0, 1]
    x = x.clamp(0, 1)
    x = x.view(x.shape[0], 1, 28, 28)
    return x


# Train the autoencoder.
# `Variable` wrappers from the original are gone: tensors participate in
# autograd directly since PyTorch 0.4.
for e in range(100):
    for im, _ in train_data:
        im = im.view(im.shape[0], -1)  # flatten to (batch, 784)
        # forward pass
        _, output = net(im)
        loss = criterion(output, im) / im.shape[0]  # mean loss per image
        # backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if (e + 1) % 20 == 0:  # every 20 epochs, save the generated images
        print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item()))
        pic = to_img(output.cpu().data)
        if not os.path.exists('./simple_autoencoder'):
            os.mkdir('./simple_autoencoder')
        save_image(pic, './simple_autoencoder/image_{}.png'.format(e + 1))

# Decode a hand-picked code (1.19, -3.36, 2.06) into an image.
code = torch.tensor([[1.19, -3.36, 2.06]])
decode = net.decoder(code)
decode_img = to_img(decode).squeeze()
decode_img = decode_img.data.numpy() * 255  # .data detaches from autograd
plt.imshow(decode_img.astype('uint8'), cmap='gray')  # generated image


# A better approach is a convolutional network; model below.
class conv_autoencoder(nn.Module):
    """Convolutional autoencoder for (b, 1, 28, 28) MNIST batches."""

    def __init__(self):
        super(conv_autoencoder, self).__init__()
        self.encoder = nn.Sequential(
            nn.Conv2d(1, 16, 3, stride=3, padding=1),   # (b, 16, 10, 10)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=2),                   # (b, 16, 5, 5)
            nn.Conv2d(16, 8, 3, stride=2, padding=1),    # (b, 8, 3, 3)
            nn.ReLU(True),
            nn.MaxPool2d(2, stride=1)                    # (b, 8, 2, 2)
        )
        self.decoder = nn.Sequential(
            nn.ConvTranspose2d(8, 16, 3, stride=2),              # (b, 16, 5, 5)
            nn.ReLU(True),
            nn.ConvTranspose2d(16, 8, 5, stride=3, padding=1),   # (b, 8, 15, 15)
            nn.ReLU(True),
            nn.ConvTranspose2d(8, 1, 2, stride=2, padding=1),    # (b, 1, 28, 28)
            nn.Tanh()
        )

    def forward(self, x):
        """Return (code, reconstruction) for an image batch of shape (b, 1, 28, 28)."""
        encode = self.encoder(x)
        decode = self.decoder(encode)
        return encode, decode