1. 程式人生 > >CNTK API文件翻譯(20)——GAN處理MSIST資料基礎

CNTK API文件翻譯(20)——GAN處理MSIST資料基礎

完成本期教程需要完成本系列的第四篇教程。

介紹

生成模型在深度學習的半監督或者非監督學習領域引起了廣泛的專注,這些領域傳統上都是使用判別模型的。生成模型的思想是線收集某個研究領域巨量的資料,然後訓練得到一個可以生成這樣的資料集的模型。這是一個需要大量訓練和海量資料的熱門研究領域。根據OpenAI部落格的觀點,這種方法可能可以用於進行計算機輔助藝術的創作,或者根據語言描述來對圖片進行一些改變比如“讓我的笑容更明媚”。這種方法目前已被用於影象去燥、影象修復、增加影象解析度、影象結構識別,而且在增強學習、神經網路預訓練這種標記資料代價高昂的領域,也有深入的研究。

生成模型能夠產生與現實資料高度相似的內容(影象,聲音等)是非常困難的。生成對抗網路(Generative Adversarial Network,GAN)是實現上訴描述的方法之一。一個來自LeCun summarizes的文章(地址:

https://www.quora.com/What-are-some-recent-and-potentially-upcoming-breakthroughs-in-deep-learning)總結了GAN和GAN近十年的發展,我們在此展示如何使用CNTK來建立簡單的GAN來生成模擬的MNIST資料。

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import os

import cntk as C
import cntk.tests.test_utils

# (only needed for our build system)
cntk.tests.test_utils.set_device_from_pytest_env() # fix a random seed for CNTK components C.cntk_py.set_fixed_random_seed(1)

我們設定了兩種執行模式:

  • 快速模式:isFast變數設定成True。這是我們的預設模式,在這個模式下我們會訓練更少的次數,也會使用更少的資料,這個模式保證功能的正確性,但訓練的結果還遠遠達不到可用的要求。
  • 慢速模式:我們建議學習者在學習的時候試試將isFast變數設定成False,這會讓學習者更加了解本教程的內容。

  • 注意如果isFast被設為False,在有GPU的機器上程式碼將執行幾個小時。你可以試試通過吧num_minibatches設定成一個較小的數字比如20000,減少迴圈次數,不過帶來的代價就是生成影象質量的降低。
isFast = True

資料讀取

GAN網路的輸入將會是一個由隨機陣列成的向量。在訓練結束是,GNA學會生成像MNIST資料集中一樣的手寫數字的圖片。我們將使用與第四期下載的資料,一些關於資料格式的討論和讀取方法在之前的教程中有涉及到。在本教程中,只要知道下面的方法返回一個用來從MNIST資料集中生成影象的物件。因為我們是在建立一個非監督學習模型,我們只讀取features,而不管labels。

# Ensure the training data is generated and available for this tutorial
# We search in two locations in the toolkit for the cached MNIST data set.

data_found = False
for data_dir in [os.path.join("..", "Examples", "Image", "DataSets", "MNIST"),
                 os.path.join("data", "MNIST")]:
    train_file = os.path.join(data_dir, "Train-28x28_cntk_text.txt")
    if os.path.isfile(train_file):
        data_found = True
        break

if not data_found:
    raise ValueError("Please generate the data by completing CNTK 103 Part A")

print("Data directory is {0}".format(data_dir))


def create_reader(path, is_training, input_dim, label_dim):
    deserializer = C.io.CTFDeserializer(
        filename = path,
        streams = C.io.StreamDefs(
            labels_unused = C.io.StreamDef(field = 'labels', shape = label_dim, is_sparse = False),
            features = C.io.StreamDef(field = 'features', shape = input_dim, is_sparse = False
            )
        )
    )
    return C.io.MinibatchSource(
        deserializers = deserializer,
        randomize = is_training,
        max_sweeps = C.io.INFINITELY_REPEAT if is_training else 1
    )

我們用於訓練GAN的隨機噪音使用noise_sample方法隨機生成一些[-1,1]之間正態分佈的噪音樣本。

np.random.seed(123)
def noise_sample(num_samples):
    return np.random.uniform(
        low = -1.0,
        high = 1.0,
        size = [num_samples, g_input_dim]        
    ).astype(np.float32)

模型建立

GNA網路由兩個子網路組成,一個叫生成器(Generator,G),一個叫判別器(Discriminator ,D)。

  • 生成器以隨機噪音向量z為輸入引數,努力生成與MNIST資料集中的真實影象($x$)相似的合成影象($x^*$)
  • 判別器努力區分真實影象($x$)和合成影象($x^*$)之間的區別。

image

在每輪訓練中,生成器都會生成更加真實的合成影象(也就是減少合成影象和真是影象之間的差),同時判別器最大化給真實和生成的影象帖對真實或生成的標籤的概率。GNA兩個子網路中的衝突導致他收斂於一個平衡,此時生成器生成看起來很像MNIST影象的合成照片,判別器可以最多的隨機猜測那個圖片是真實的,哪個圖片是合成的。訓練的結果就是生成模型以隨機的輸入數字得到逼真的MNIST影象。

模型配置

首先,我們設定一些模型結構和訓練超引數。

  • 生成網路是一個有一個隱藏層的全連線網路,輸入資料是一個100維隨機向量,輸出資料會是一個784維的向量,對應28×28影象的扁平狀態。判別器也是一個單層全連線網路,以生成器生成的784維向量或來自真實MNIST資料集的784維向量作為輸入,輸出一個代表輸入資料是真實MNIST資料概率的標量。

模型構成

我們為我們的模型構建計算圖,一個給生成器一個給判別器。首先我們我們建立一些模型結構引數。

  • 生成器輸入100維隨機向量($z$)輸出一個784維的向量,對應28×28合成影象($x^*$)的扁平狀態。在本教程中,我們簡單將我們的生成器構造為兩個全連線層。我們在最後一層使用tanh啟用函式確保生成器函式的輸出值在閉區間[-1,1]之間。因為之前也將MNIST影象對映到了這個範圍內,所以這步操作是有必要的。
  • 判別器輸入從生成器中輸出的或者來自真實MNIST影象的784維向量($x^*$),輸出輸入影象是真實MNIST影象的概率。我們也使用兩個全連線層構建判別器,最後一層使用sigmoid啟用函式,以此保證判別器的輸出值是一個有效的概率。
# architectural parameters
g_input_dim = 100
g_hidden_dim = 128
g_output_dim = d_input_dim = 784
d_hidden_dim = 128
d_output_dim = 1

def generator(z):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(g_hidden_dim, activation = C.relu)(z)
        return C.layers.Dense(g_output_dim, activation = C.tanh)(h1)

def discriminator(x):
    with C.layers.default_options(init = C.xavier()):
        h1 = C.layers.Dense(d_hidden_dim, activation = C.relu)(x)
        return C.layers.Dense(d_output_dim, activation = C.sigmoid)(h1)

我們使用的取樣包數大小是1024,固定學習速率0.0005.如果使用快速模式我們只訓練300輪以證明其功能正確性。

注意:在慢速模式,結果看起來會比快速模式好得多,不過根據你訓練電腦的配置,你可能會登上幾個小時到十幾個小時不等。一般來說,取樣包訓練的越多,生成的影象越逼真。

# training config
minibatch_size = 1024
num_minibatches = 300 if isFast else 40000
lr = 0.00005

構建計算圖

計算圖的剩下部分主要用於協調訓練演算法和引數更新,這由於以下原因對GAN十分困難。

  • 第一,判別器必須既用於真實MNIST影象,也用於生成器函式生成的模擬影象。一種在計算圖上記錄上訴狀態的方法是建立一個判別器函式輸出的克隆副本,但是用不同的輸入。在副本函式中設定method=share確保不同方式使用的判別器使用一樣的引數。
  • 第二,我們需要對生成器和判別器使用不同的成本函式來更新模型引數。我們可以通過parameters屬性獲取計算圖中函式物件的引數。然而,當更新模型引數時,更新只發生在兩個子網路中的一個,另一個沒有改變。換句話說,當更新生成器的引數時,我們只更新了G函式的引數,沒有更新D函式的引數。

訓練模型

訓練GAN的程式碼與2014年神經資訊處理系統大會(NIPS)上的一篇論文(連結:https://arxiv.org/pdf/1406.2661v1.pdf)提出的演算法非常接近。在實現是,我們訓練D來最大化給訓練樣本和G中生產的樣本貼正確標籤的概率。換句話說,D和G在玩一個雙人針對函式V(G,D)極大極小值遊戲。

minGmaxDV(D,G)=Ex[logD(x)]+Ez[log(1D(G(z)))]

這個遊戲的最優點,生成器將生成非常逼真的資料,判別器預測合成圖片的概率將會變成0.5。上面提到的論文中提到的演算法會在下面的程式碼中實現。

image

ef build_graph(noise_shape, image_shape,
                G_progress_printer, D_progress_printer):
    input_dynamic_axes = [C.Axis.default_batch_axis()]
    Z = C.input_variable(noise_shape, dynamic_axes=input_dynamic_axes)
    X_real = C.input_variable(image_shape, dynamic_axes=input_dynamic_axes)
    X_real_scaled = 2*(X_real / 255.0) - 1.0

    # Create the model function for the generator and discriminator models
    X_fake = generator(Z)
    D_real = discriminator(X_real_scaled)
    D_fake = D_real.clone(
        method = 'share',
        substitutions = {X_real_scaled.output: X_fake.output}
    )

    # Create loss functions and configure optimazation algorithms
    G_loss = 1.0 - C.log(D_fake)
    D_loss = -(C.log(D_real) + C.log(1.0 - D_fake))

    G_learner = C.fsadagrad(
        parameters = X_fake.parameters,
        lr = C.learning_rate_schedule(lr, C.UnitType.sample),
        momentum = C.momentum_as_time_constant_schedule(700)
    )
    D_learner = C.fsadagrad(
        parameters = D_real.parameters,
        lr = C.learning_rate_schedule(lr, C.UnitType.sample),
        momentum = C.momentum_as_time_constant_schedule(700)
    )

    # Instantiate the trainers
    G_trainer = C.Trainer(
        X_fake,
        (G_loss, None),
        G_learner,
        G_progress_printer
    )
    D_trainer = C.Trainer(
        D_real,
        (D_loss, None),
        D_learner,
        D_progress_printer
    )

    return X_real, X_fake, Z, G_trainer, D_trainer

隨著定義值函式,我們開始對GAN模型進行間接訓練。訓練這個模型根據硬體狀況將會話費很長時間特別是如果你把isFast設為False。

def train(reader_train):
    k = 2

    # print out loss for each model for upto 50 times
    print_frequency_mbsize = num_minibatches // 50
    pp_G = C.logging.ProgressPrinter(print_frequency_mbsize)
    pp_D = C.logging.ProgressPrinter(print_frequency_mbsize * k)

    X_real, X_fake, Z, G_trainer, D_trainer = \
        build_graph(g_input_dim, d_input_dim, pp_G, pp_D)

    input_map = {X_real: reader_train.streams.features}
    for train_step in range(num_minibatches):

        # train the discriminator model for k steps
        for gen_train_step in range(k):
            Z_data = noise_sample(minibatch_size)
            X_data = reader_train.next_minibatch(minibatch_size, input_map)
            if X_data[X_real].num_samples == Z_data.shape[0]:
                batch_inputs = {X_real: X_data[X_real].data, 
                                Z: Z_data}
                D_trainer.train_minibatch(batch_inputs)

        # train the generator model for a single step
        Z_data = noise_sample(minibatch_size)
        batch_inputs = {Z: Z_data}
        G_trainer.train_minibatch(batch_inputs)

        G_trainer_loss = G_trainer.previous_minibatch_loss_average

    return Z, X_fake, G_trainer_loss


reader_train = create_reader(train_file, True, d_input_dim, label_dim=10)

G_input, G_output, G_trainer_loss = train(reader_train)

生成合成圖片

現在我們訓練好了這個模型,我們能通過簡單的給生成器傳入隨機噪音來創造合成圖片病展示他們。下面就生成的圖片裡的一些隨機樣本,要看其他照片,你只需要重新執行下面的程式碼。

def plot_images(images, subplot_shape):
    plt.style.use('ggplot')
    fig, axes = plt.subplots(*subplot_shape)
    for image, ax in zip(images, axes.flatten()):
        ax.imshow(image.reshape(28, 28), vmin = 0, vmax = 1.0, cmap = 'gray')
        ax.axis('off')
    plt.show()

noise = noise_sample(36)
images = G_output.eval({G_input: noise})
plot_images(images, subplot_shape =[6, 6])

image

大量的迭代會生成看起來更像MNIST資料集的圖片。一個更好的效果展示如下。
image

注意:要獲取真實世界的訊號需要通過大量的迭代。即使MNIST是一個非常簡單的資料,全連線網路在資料建模方面也非常有效。


歡迎掃碼關注我的微信公眾號獲取最新文章
image