
A Simple Implementation of InfoGAN


In this implementation, the mutual-information maximization does not share layers with the D network; instead, a separate, simple MLP is used as the Q network.
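For reference, the crossent_loss in the code below implements InfoGAN's variational lower bound on the mutual information between the code c and the generated sample x = G(z, c), with Q as the auxiliary distribution:

I(c; G(z,c)) \ge \mathbb{E}_{c \sim p(c),\, x \sim G(z,c)}\big[\log Q(c \mid x)\big] + H(c)

Since c is drawn from a fixed uniform categorical distribution, H(c) is a constant, so maximizing the bound reduces to minimizing the cross-entropy \mathbb{E}[-\log Q(c \mid x)].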

import os, sys
sys.path.append("/home/hxj/anaconda3/lib/python3.6/site-packages")
import torch
import torch.nn.functional as F
import torch.autograd as autograd
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from torch.autograd import Variable
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('./MNIST_data', one_hot=True)

mb_size = 32
Z_dim = 16
X_dim = mnist.train.images.shape[1]  # 784
y_dim = mnist.train.labels.shape[1]  # 10
h_dim = 128
cnt = 0
lr = 1e-3


def xavier_init(size):
    in_dim = size[0]
    xavier_stddev = 1. / np.sqrt(in_dim / 2.)
    return Variable(torch.randn(*size) * xavier_stddev, requires_grad=True)


""" ==================== GENERATOR ======================== """

Wzh = xavier_init(size=[Z_dim + 10, h_dim])  # shape 26 x 128
bzh = Variable(torch.zeros(h_dim), requires_grad=True)

Whx = xavier_init(size=[h_dim, X_dim])  # shape 128 x 784
bhx = Variable(torch.zeros(X_dim), requires_grad=True)


def G(z, c):
    # The generator conditions on the concatenation of noise z and code c.
    inputs = torch.cat([z, c], 1)
    h = F.relu(inputs @ Wzh + bzh.repeat(inputs.size(0), 1))
    X = F.sigmoid(h @ Whx + bhx.repeat(h.size(0), 1))
    return X


""" ==================== DISCRIMINATOR ======================== """

Wxh = xavier_init(size=[X_dim, h_dim])
bxh = Variable(torch.zeros(h_dim), requires_grad=True)

Why = xavier_init(size=[h_dim, 1])
bhy = Variable(torch.zeros(1), requires_grad=True)


def D(X):
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    y = F.sigmoid(h @ Why + bhy.repeat(h.size(0), 1))
    return y


""" ====================== Q(c|X) ========================== """

Wqxh = xavier_init(size=[X_dim, h_dim])
bqxh = Variable(torch.zeros(h_dim), requires_grad=True)

Whc = xavier_init(size=[h_dim, 10])
bhc = Variable(torch.zeros(10), requires_grad=True)


def Q(X):
    # A standalone MLP; it does not share layers with D.
    h = F.relu(X @ Wqxh + bqxh.repeat(X.size(0), 1))
    c = F.softmax(h @ Whc + bhc.repeat(h.size(0), 1), dim=1)
    return c


G_params = [Wzh, bzh, Whx, bhx]
D_params = [Wxh, bxh, Why, bhy]
Q_params = [Wqxh, bqxh, Whc, bhc]
params = G_params + D_params + Q_params

""" ===================== TRAINING ======================== """


def reset_grad():
    for p in params:
        if p.grad is not None:
            data = p.grad.data
            p.grad = Variable(data.new().resize_as_(data).zero_())


G_solver = optim.Adam(G_params, lr=1e-3)
D_solver = optim.Adam(D_params, lr=1e-3)
# The mutual-information term updates both G and Q, so they share one optimizer.
Q_solver = optim.Adam(G_params + Q_params, lr=1e-3)


def sample_c(size):
    # One-hot categorical code, each of the 10 categories with probability 0.1.
    c = np.random.multinomial(1, 10 * [0.1], size=size)
    c = Variable(torch.from_numpy(c.astype(np.float32)))
    return c


for it in range(100000):
    # Sample data
    X, _ = mnist.train.next_batch(mb_size)
    X = Variable(torch.from_numpy(X))          # real batch, 32 x 784
    z = Variable(torch.randn(mb_size, Z_dim))  # noise, 32 x 16
    c = sample_c(mb_size)                      # random one-hot codes, 32 x 10

    # Discriminator forward-loss-backward-update
    G_sample = G(z, c)
    D_real = D(X)
    D_fake = D(G_sample)

    D_loss = -torch.mean(torch.log(D_real + 1e-8) + torch.log(1 - D_fake + 1e-8))

    D_loss.backward()
    D_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Generator forward-loss-backward-update
    G_sample = G(z, c)
    D_fake = D(G_sample)

    G_loss = -torch.mean(torch.log(D_fake + 1e-8))

    G_loss.backward()
    G_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Q forward-loss-backward-update
    # The fake samples generated under code c are fed through Q as well,
    # not only used to train G and D.
    G_sample = G(z, c)
    Q_c_given_x = Q(G_sample)

    # Cross-entropy between c and Q(c|x); minimizing it maximizes the
    # mutual information between the code and the generated sample.
    crossent_loss = torch.mean(-torch.sum(c * torch.log(Q_c_given_x + 1e-8), dim=1))
    mi_loss = crossent_loss

    mi_loss.backward()
    Q_solver.step()

    # Housekeeping - reset gradient
    reset_grad()

    # Print and plot every now and then
    if it % 1000 == 0:
        idx = np.random.randint(0, 10)
        c = np.zeros([mb_size, 10])
        c[range(mb_size), idx] = 1
        c = Variable(torch.from_numpy(c.astype(np.float32)))
        samples = G(z, c).data.numpy()[:16]

        print('Iter-{}; D_loss: {}; G_loss: {}; Idx: {}'
              .format(it, D_loss.data.numpy(), G_loss.data.numpy(), idx))

        fig = plt.figure(figsize=(4, 4))
        gs = gridspec.GridSpec(4, 4)
        gs.update(wspace=0.05, hspace=0.05)

        for i, sample in enumerate(samples):
            ax = plt.subplot(gs[i])
            plt.axis('off')
            ax.set_xticklabels([])
            ax.set_yticklabels([])
            ax.set_aspect('equal')
            plt.imshow(sample.reshape(28, 28), cmap='Greys_r')

        if not os.path.exists('out/'):
            os.makedirs('out/')

        plt.savefig('out/{}.png'.format(str(cnt).zfill(3)), bbox_inches='tight')
        cnt += 1
        plt.close(fig)
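As noted above, this implementation gives Q its own MLP instead of sharing D's network. In the original InfoGAN paper, Q shares most of D's layers and only adds a small recognition head. A minimal sketch of that variant, reusing Wxh/bxh from the code above (Wqhc and bqhc are hypothetical new head parameters, not part of the original post):

Wqhc = xavier_init(size=[h_dim, 10])
bqhc = Variable(torch.zeros(10), requires_grad=True)

def Q_shared(X):
    # Reuse D's hidden layer as a shared trunk, then apply Q's own softmax head.
    h = F.relu(X @ Wxh + bxh.repeat(X.size(0), 1))
    return F.softmax(h @ Wqhc + bqhc.repeat(h.size(0), 1), dim=1)

With sharing, the MI optimizer would also have to cover the shared trunk (e.g. optim.Adam(G_params + [Wxh, bxh, Wqhc, bqhc], lr=1e-3)), which is the main reason the standalone-MLP version here is simpler to reason about.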
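After training, a quick way to check whether the categorical code actually disentangles digit type is to hold z fixed and sweep c over all 10 one-hot values. A sketch reusing the names defined above (the output file name is illustrative):

z_fixed = Variable(torch.randn(1, Z_dim).repeat(10, 1))  # same noise in every column
c_sweep = Variable(torch.eye(10))                        # one-hot codes 0..9
samples = G(z_fixed, c_sweep).data.numpy()

fig = plt.figure(figsize=(10, 1))
gs = gridspec.GridSpec(1, 10)
gs.update(wspace=0.05, hspace=0.05)
for i, sample in enumerate(samples):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    plt.imshow(sample.reshape(28, 28), cmap='Greys_r')
plt.savefig('out/c_sweep.png', bbox_inches='tight')
plt.close(fig)

If training succeeded, each column should show a different, consistent digit even though z never changes.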
