
[Keras] GAN Neural Networks

References:

The main reference is the paper Generative Adversarial Networks, link

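To keep the explanation and the experiments simple, I built only a very small model here, one that generates samples from a Gaussian distribution. Even so, the experiments below revealed several useful properties that deepen our understanding of GANs.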

How a GAN works

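The details are in the paper referenced above, but here is a brief overview.
The idea behind a GAN is simple. It consists of two sub-networks. The Generator takes noise samples as input and, through its learned weights, transforms (that is, generates) the noise into a meaningful signal. The Discriminator takes a signal as input (either one produced by the Generator or a real one), learns to tell real from fake, and outputs a probability between 0 and 1. You can think of the Generator as a counterfeit money printer and the Discriminator as a banknote detector: the two compete, so the printer's output becomes more and more realistic while the detector becomes more and more accurate. In the end, though, what we want is for the Generator's output to become realistic enough that the Discriminator outputs 0.5 everywhere, i.e. it can no longer tell the two apart.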
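The 0.5 target can be made precise. As a short derivation following the cited paper (the symbols p_data and p_g for the real and generated densities are notation introduced here and are not used elsewhere in this post), for a fixed Generator the optimal Discriminator is:

```latex
D^{*}(x) = \frac{p_{\text{data}}(x)}{p_{\text{data}}(x) + p_{g}(x)},
\qquad
p_{g} = p_{\text{data}} \;\Longrightarrow\; D^{*}(x) = \frac{1}{2}.
```

So an output of 0.5 everywhere corresponds to the generated distribution matching the real one.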

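Training proceeds in two phases. The first phase trains the Discriminator: the Generator's weights are held fixed and only the Discriminator's weights are updated. Following the cited paper, the objective is:

$$\frac{1}{m}\sum_{i=1}^{m}\left[\log D\left(x^{(i)}\right) + \log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)\right]$$

where m is the batch_size, x denotes a real signal, and z denotes a noise sample. In each step, m noise samples and m real samples are drawn from the noise distribution and the real distribution respectively, and the Discriminator's weights are updated to maximize this objective.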
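As a rough numerical illustration of the Discriminator objective from the cited paper, the sketch below evaluates it on hand-picked Discriminator outputs. The function name `discriminator_objective` and the clipping value `eps` are assumptions made for this example; they are not part of the original code.

```python
import numpy as np

def discriminator_objective(d_real, d_fake, eps=1e-7):
    """Mean of log D(x) + log(1 - D(G(z))) over a batch.

    d_real: discriminator outputs on real samples, D(x)
    d_fake: discriminator outputs on generated samples, D(G(z))
    eps:    lower clip to avoid log(0); an assumed value for this sketch
    """
    d_real = np.maximum(eps, np.asarray(d_real, dtype=float))
    one_minus_fake = np.maximum(eps, 1.0 - np.asarray(d_fake, dtype=float))
    return float(np.mean(np.log(d_real) + np.log(one_minus_fake)))

# A discriminator that cannot tell real from fake outputs 0.5 everywhere.
confused = discriminator_objective([0.5, 0.5], [0.5, 0.5])
# A sharper discriminator scores real samples high and generated ones low.
sharp = discriminator_objective([0.9, 0.8], [0.1, 0.2])
print(confused, sharp)
```

The sharper discriminator gets the larger value, which is why this quantity is maximized with respect to the Discriminator's weights.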

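The second phase trains the Generator. Following the cited paper, the loss function is:

$$\frac{1}{m}\sum_{i=1}^{m}\log\left(1 - D\left(G\left(z^{(i)}\right)\right)\right)$$

This time the loss is minimized to update the Generator's weights.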
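To see why the Generator minimizes this quantity, the following sketch evaluates the Generator loss from the cited paper on two hand-picked batches of Discriminator outputs. The name `generator_loss` and the clipping value `eps` are assumptions introduced for illustration.

```python
import numpy as np

def generator_loss(d_fake, eps=1e-7):
    """Mean of log(1 - D(G(z))) over a batch; the Generator minimizes this."""
    one_minus_fake = np.maximum(eps, 1.0 - np.asarray(d_fake, dtype=float))
    return float(np.mean(np.log(one_minus_fake)))

loss_caught = generator_loss([0.1, 0.2])  # discriminator rejects the fakes
loss_fooled = generator_loss([0.8, 0.9])  # discriminator accepts the fakes
print(loss_caught, loss_fooled)
```

The loss is lower when the Discriminator is fooled, so pushing it down pushes D(G(z)) toward 1.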

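Note also that the two phases do not alternate one-for-one: K Discriminator updates are performed, followed by a single Generator update.
The experimental results later on also show that the choice of K is critical.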
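The K-to-1 schedule can be sketched as a plain Python loop. The names `run_schedule`, `discriminator_step`, and `generator_step` are hypothetical placeholders introduced for illustration; each step function stands in for a real weight update.

```python
def run_schedule(epochs, k, discriminator_step, generator_step):
    """Per epoch: run k discriminator updates, then one generator update."""
    for _ in range(epochs):
        for _ in range(k):
            discriminator_step()
        generator_step()

counts = {"discriminator": 0, "generator": 0}

def discriminator_step():
    # Placeholder for a real Discriminator weight update.
    counts["discriminator"] += 1

def generator_step():
    # Placeholder for a real Generator weight update.
    counts["generator"] += 1

run_schedule(epochs=3, k=20,
             discriminator_step=discriminator_step,
             generator_step=generator_step)
print(counts)
```

With K = 20 and 3 epochs the Discriminator is updated 60 times and the Generator 3 times, which makes the imbalance controlled by K easy to see.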

Implementation

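The main tools are Python + Keras. Implementing common networks in Keras is very easy, for example MLP, word2vec, LeNet, and LSTM; detailed demos for all of these are on GitHub. Anything slightly more complex takes some time to write yourself, but overall it is still more convenient than writing raw TF. We can also treat Keras itself as reference code for learning TF, since many of its idioms are worth borrowing.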

Enough talk, let's go straight to the code:

GANmodel

Only the most important code is listed


# Loss functions designed specifically for the GAN
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))
    
def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))
    
class GANModel:
    def __init__(self, 
                 input_dim,
                 log_dir = None):
        '''
            __tensors[0]: the discriminator expression; it judges y, the true samples
            __tensors[1]: the generator expression; it generates from x, the noise samples
        '''
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim) 
        # The inputs must be named so the two signals can be fed separately later
        self.__inputs = [layers.Input(shape=(input_dim_y,), name = "y"), 
                            layers.Input(shape=(input_dim_x,), name = "x")]
        self.__tensors = [None, None] 
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)
        
    def add_gen_layer(self, layer):
        self._add_layer(layer, True)
    def add_discr_layer(self, layer):
        self._add_layer(layer)
    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)
        
        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])
            
    def compile_discriminateor_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling the discriminator model")
        # Setting trainable = False freezes a layer's weights. It must be set before compile
        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False
        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
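        # With two outputs, Keras applies the loss function to each output separately,
        # sums the results, and minimizes the sum. The models with a double-underscore
        # prefix are the ones that actually take part in training.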
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
       
        # This is the actual discriminator model
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
        if self.log_dir is not None:
            # Requires pydot and graphviz; comment this out if they are not installed
            plot_model(self.__discriminateor_model, self.log_dir + "/gan_discriminateor_model.png", show_shapes = True) 
        
    def compile_generator_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling the generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile it")
        
        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True
              
        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, 
                                     loss = log_loss_generator)
        # This is the actual Generator model
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        if self.log_dir is not None:
            plot_model(self.__generator_model, self.log_dir + "/gan_generator_model.png", show_shapes = True) 

    def train(self, sample_list, epoch = 3, batch_size = 32, step_per = 10, plot=False):
        '''
        step_per: number of discriminator updates per generator update, i.e. K
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size 
        # A bit of a trick: a Keras model always needs a target y, but a GAN has none,
        # so we fabricate one to satisfy Keras's "unreasonable" requirement
        psudo_y = np.ones((batch_size, ), dtype = 'float32')
        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1,1,1)
            plt.ion()
            plt.show() 
        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count-1)
                batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
                idx = random.randint(0, batch_count-1)
                batch_sample = sample_true[idx * batch_size : (idx+1) * batch_size]
                self.__discriminateor_model.train_on_batch({
                    "y":  batch_sample,
                    "x": batch_noise}, 
                    [psudo_y, psudo_y])

            idx = random.randint(0, batch_count-1)
            batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)
            
            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)
                freq_g, bin_g = np.histogram(gen_result, density=True)
                # scale the densities by the bin width so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize = 4)
                ax.plot(bin_d, freq_d, 'ko-', markersize = 8)
                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei+1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei+1))
                plt.pause(0.05)
        if plot:
            plt.ioff()
            plt.close()
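
The plotting code above calls `np.histogram` with `density=True` and then multiplies by the bin width. A minimal sketch of why that yields per-bin frequencies that sum to 1 is shown below; the random seed and sample size are arbitrary choices for this example, and the multiplication by a single width relies on the bins being equally wide, which is the default.

```python
import numpy as np

rng = np.random.RandomState(0)
samples = rng.normal(size=1000)

# density=True returns a probability density over 10 equal-width bins (the default).
density, edges = np.histogram(samples, density=True)
bin_width = edges[1] - edges[0]
freq = density * bin_width   # fraction of samples falling in each bin
left_edges = edges[:-1]      # one x position per bin, as in the plotting code

print(freq.sum())
```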

The main section

Only the main part is listed; it shows the overall model structure and the parameter values

    step_per = 20
    sample_size = args.batch_size * 100

    # The full set of test samples
    noise_dim = 4
    signal_dim = 1
    x = np.random.uniform(-3, 3, size = (sample_size, noise_dim))
    y = np.random.normal(size = (sample_size, signal_dim))
    samples = {"x": x, 
               "y": y}
    
    gan = GANModel([signal_dim, noise_dim], args.log_dir)
    gan.add_discr_layer(layers.Dense(200, activation="relu"))
    gan.add_discr_layer(layers.Dense(50, activation="softmax"))
    gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                 output_shape = (1,)))

    gan.add_gen_layer(layers.Dense(200, activation="relu"))
    gan.add_gen_layer(layers.Dense(100, activation="relu"))
    gan.add_gen_layer(layers.Dense(50, activation="relu"))
    gan.add_gen_layer(layers.Dense(signal_dim))
    
    gan.compile_generator_model()
    loger.info("compile generator finished")
    gan.compile_discriminateor_model()
    loger.info("compile discriminator finished")
    
    gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)
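
The discriminator head above is a 50-way softmax followed by a max over the last axis. The NumPy sketch below (the `softmax` helper and the random logits are stand-ins introduced here, not part of the original code) checks what range such an output can take: the largest of n softmax probabilities is at least 1/n and at most 1, so the score stays strictly positive, which the log-based losses require.

```python
import numpy as np

def softmax(logits):
    """Numerically stable softmax over the last axis."""
    shifted = logits - logits.max(axis=-1, keepdims=True)
    exp = np.exp(shifted)
    return exp / exp.sum(axis=-1, keepdims=True)

rng = np.random.RandomState(0)
# Stand-in for the pre-activations of the Dense(50) layer.
logits = rng.normal(scale=5.0, size=(1000, 50))
scores = softmax(logits).max(axis=-1, keepdims=True)

# All-equal logits give the smallest possible score, 1/50.
uniform_score = softmax(np.zeros((1, 50))).max()

print(scores.shape, scores.min(), scores.max(), uniform_score)
```

One consequence worth keeping in mind is that this head can never output a score below 1/50 = 0.02.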

Full code

# demo_gan.py
# -*- encoding: utf8 -*-
'''
GAN network demo
'''
import os
from os import path
import argparse
import logging
import traceback
import random
import pickle
import numpy as np
import tensorflow as tf
from keras import optimizers 
from keras import layers
from keras import callbacks, regularizers, activations
from keras.models import Model
from keras.utils.vis_utils import plot_model
import keras.backend as K
from collections import defaultdict
from matplotlib import pyplot as plt
import app_logger

loger = logging.getLogger(__name__)

# Note that pred must not be negative, because it is a probability; choose the final activation function accordingly
def log_loss_discriminator(y_true, y_pred):
    return - K.log(K.maximum(K.epsilon(), y_pred))
    
def log_loss_generator(y_true, y_pred):
    return K.log(K.maximum(K.epsilon(), 1. - y_pred))

class GANModel:
    def __init__(self, 
                 input_dim,
                 log_dir = None):
        '''
            __tensors[0]: the discriminator expression
            __tensors[1]: the generator expression
        '''
        # The discriminator judges y, the true samples
        # The generator generates from x, the noise samples
        if isinstance(input_dim, list):
            input_dim_y, input_dim_x = input_dim[0], input_dim[1]
        elif isinstance(input_dim, int):
            input_dim_x = input_dim_y = input_dim
        else:
            raise ValueError("input_dim should be list or integer, got %r" % input_dim) 
    
        self.__inputs = [layers.Input(shape=(input_dim_y,), name = "y"), 
                            layers.Input(shape=(input_dim_x,), name = "x")]
        self.__tensors = [None, None] 
        self.log_dir = log_dir
        self._discriminate_layers = []
        self._generate_layers = []
        self.train_status = defaultdict(list)
        
    def add_gen_layer(self, layer):
        self._add_layer(layer, True)
    def add_discr_layer(self, layer):
        self._add_layer(layer)
    def _add_layer(self, layer, for_gen=False):
        idx = 0
        if for_gen:
            self._generate_layers.append(layer)
            idx = 1
        else:
            self._discriminate_layers.append(layer)
        
        if self.__tensors[idx] is None:
            self.__tensors[idx] = layer(self.__inputs[idx])
        else:
            self.__tensors[idx] = layer(self.__tensors[idx])
            
    def compile_discriminateor_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling it")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build the generator model before compiling the discriminator model")
        
        for l in self._discriminate_layers:
            l.trainable = True
        for l in self._generate_layers:
            l.trainable = False
        discriminateor_out1 = self.__tensors[0]
        discriminateor_out2 = layers.Lambda(lambda y: 1. - y)(self._discriminate_generated())
        self.__discriminateor_model = Model(self.__inputs, [discriminateor_out1, discriminateor_out2])
        self.__discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
       
        # This is the discriminator model we actually need
        self.discriminateor_model = Model(self.__inputs[0], self.__tensors[0])
        self.discriminateor_model.compile(optimizer, 
                                     loss = log_loss_discriminator)
        #if self.log_dir is not None:
        #    plot_model(self.__discriminateor_model, self.log_dir + "/gan_discriminateor_model.png", show_shapes = True) 
        
    def compile_generator_model(self, optimizer = optimizers.Adam()):
        if len(self._discriminate_layers) <= 0:
            raise ValueError("you need to build the discriminator model before compiling the generator model")
        if len(self._generate_layers) <= 0:
            raise ValueError("you need to build generator model before compile it")
        
        for l in self._discriminate_layers:
            l.trainable = False
        for l in self._generate_layers:
            l.trainable = True
              
        out = self._discriminate_generated()
        self.__generator_model = Model(self.__inputs[1], out)
        self.__generator_model.compile(optimizer, 
                                     loss = log_loss_generator)
        # This is the generator model we actually need
        self.generator_model = Model(self.__inputs[1], self.__tensors[1])
        #if self.log_dir is not None:
        #    plot_model(self.__generator_model, self.log_dir + "/gan_generator_model.png", show_shapes = True) 

    def train(self, sample_list, epoch = 3, batch_size = 32, step_per = 10, plot=False):
        '''
        step_per: number of discriminator updates per generator update, i.e. K
        '''
        sample_noise, sample_true = sample_list["x"], sample_list["y"]
        sample_count = sample_noise.shape[0]
        batch_count = sample_count // batch_size 
        psudo_y = np.ones((batch_size, ), dtype = 'float32')
        if plot:
            # plot the real data
            fig = plt.figure()
            ax = fig.add_subplot(1,1,1)
            plt.ion()
            plt.show() 
        for ei in range(epoch):
            for i in range(step_per):
                idx = random.randint(0, batch_count-1)
                batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
                idx = random.randint(0, batch_count-1)
                batch_sample = sample_true[idx * batch_size : (idx+1) * batch_size]
                self.__discriminateor_model.train_on_batch({
                    "y":  batch_sample,
                    "x": batch_noise}, 
                    [psudo_y, psudo_y])

            idx = random.randint(0, batch_count-1)
            batch_noise = sample_noise[idx * batch_size : (idx+1) * batch_size]
            self.__generator_model.train_on_batch(batch_noise, psudo_y)
            
            if plot:
                gen_result = self.generator_model.predict_on_batch(batch_noise)
                self.train_status["gen_result"].append(gen_result)
                dis_result = self.discriminateor_model.predict_on_batch(gen_result)
                self.train_status["dis_result"].append(dis_result)
                freq_g, bin_g = np.histogram(gen_result, density=True)
                # scale the densities by the bin width so the frequencies sum to 1
                freq_g = freq_g * (bin_g[1] - bin_g[0])
                bin_g = bin_g[:-1]
                freq_d, bin_d = np.histogram(batch_sample, density=True)
                freq_d = freq_d * (bin_d[1] - bin_d[0])
                bin_d = bin_d[:-1]
                ax.plot(bin_g, freq_g, 'go-', markersize = 4)
                ax.plot(bin_d, freq_d, 'ko-', markersize = 8)
                gen1d = gen_result.flatten()
                dis1d = dis_result.flatten()
                si = np.argsort(gen1d)
                ax.plot(gen1d[si], dis1d[si], 'r--')
                if (ei+1) % 20 == 0:
                    ax.cla()
                plt.title("epoch = %d" % (ei+1))
                plt.pause(0.05)
        if plot:
            plt.ioff()
            plt.close()
            
            
    def save_model(self, path_dir):
        self.generator_model.save(path_dir + "/gan_generator.h5")
        self.discriminateor_model.save(path_dir + "/gan_discriminateor.h5")
    
    def load_model(self, path_dir):
        from keras.models import load_model
        custom_obj = {
            # The key and the function must match the name defined above (log_loss_discriminator)
            "log_loss_discriminator": log_loss_discriminator,
            "log_loss_generator": log_loss_generator}
        self.generator_model = load_model(path_dir + "/gan_generator.h5", custom_obj)
        self.discriminateor_model = load_model(path_dir + "/gan_discriminateor.h5", custom_obj)
    
    def _discriminate_generated(self):
        # Must be rebuilt on every call
        disc_t = self.__tensors[1]
        for l in self._discriminate_layers:
            disc_t = l(disc_t)            
        return disc_t
    
if __name__ == "__main__":
    parser = argparse.ArgumentParser("""gan model demo (gaussian sample)""")
    parser.add_argument("-m", "--model_dir")
    parser.add_argument("-log", "--log_dir")
    parser.add_argument("-b", "--batch_size", type = int, default = 32)
    parser.add_argument("-log_lvl", "--log_lvl", default = "info",
                        metavar = "one of INFO, DEBUG, WARN, ERROR")
    parser.add_argument("-e", "--epoch", type = int, default = 10)
    
    args = parser.parse_args()
    
    log_lvl = {"info": logging.INFO,
               "debug": logging.DEBUG,
               "warn": logging.WARN,
               "warning": logging.WARN,
               "error": logging.ERROR,
               "err": logging.ERROR}[args.log_lvl.lower()]
    app_logger.init(log_lvl)
        
    loger.info("args: %r" % args)
    step_per = 20
    sample_size = args.batch_size * 100

    # The full set of test samples
    noise_dim = 4
    signal_dim = 1
    x = np.random.uniform(-3, 3, size = (sample_size, noise_dim))
    y = np.random.normal(size = (sample_size, signal_dim))
    samples = {"x": x, 
               "y": y}
    
    gan = GANModel([signal_dim, noise_dim], args.log_dir)
    gan.add_discr_layer(layers.Dense(200, activation="relu"))
    gan.add_discr_layer(layers.Dense(50, activation="softmax"))
    gan.add_discr_layer(layers.Lambda(lambda y: K.max(y, axis=-1, keepdims=True),
                                 output_shape = (1,)))

    gan.add_gen_layer(layers.Dense(200, activation="relu"))
    gan.add_gen_layer(layers.Dense(100, activation="relu"))
    gan.add_gen_layer(layers.Dense(50, activation="relu"))
    gan.add_gen_layer(layers.Dense(signal_dim))
    
    gan.compile_generator_model()
    loger.info("compile generator finished")
    gan.compile_discriminateor_model()
    loger.info("compile discriminator finished")
    
    gan.train(samples, args.epoch, args.batch_size, step_per, plot=True)
    gen_results = gan.train_status["gen_result"]
    dis_results = gan.train_status["dis_result"]

    gen_result = gen_results[-1]
    dis_result = dis_results[-1]
    freq_g, bin_g = np.histogram(gen_result, density=True)
    # scale the densities by the bin width so the frequencies sum to 1
    freq_g = freq_g * (bin_g[1] - bin_g[0])
    bin_g = bin_g[:-1]
    freq_d, bin_d = np.histogram(y, bins = 100, density=True)
    freq_d = freq_d * (bin_d[1] - bin_d[0])
    bin_d = bin_d[:-1]
    plt.plot(bin_g, freq_g, 'go-', markersize = 4)
    plt.plot(bin_d, freq_d, 'ko-', markersize = 8)
    gen1d = gen_result.flatten()
    dis1d = dis_result.flatten()
    si = np.argsort(gen1d)
    plt.plot(gen1d[si], dis1d[si], 'r--')
    os.makedirs("img", exist_ok=True)  # savefig fails if the output directory is missing
    plt.savefig("img/gan_results.png")
    if not path.exists(args.model_dir):
        os.mkdir(args.model_dir)
    gan.save_model(args.model_dir)


# app_logger.py
import logging

def init(lvl=logging.DEBUG):
    log_handler = logging.StreamHandler()
    # create formatter
    formatter = logging.Formatter('[%(asctime)s] %(levelname)s %(filename)s:%(funcName)s:%(lineno)d > %(message)s')
    log_handler.setFormatter(formatter)
    logging.basicConfig(level = lvl, handlers = [log_handler])