實戰生成對抗網路[2]:生成手寫數字
在開始本文之前,讓我們先看看一則報道:
人民網訊 據英國廣播電視公司10月25日報道,由人工智慧創作的藝術作品以432000美元(約合300萬人民幣)的高價成功拍賣。
看起來一則不起眼的新聞,其實意義深遠,它意味著人們開始認可計算機創作的藝術價值,那些沾沾自喜認為不會被人工智慧取代的藝術家也要瑟瑟發抖了。
這幅由人工智慧創作的作品長啥樣,有啥過人之處?
嗯,以我這種外行人士看來,實在不怎麼樣,但這不意味著人工智慧不行。要知道,AlphaGo初出道時,也只敢挑戰一下樊麾這樣的二流棋手,接下來挑戰頂級棋手李世石,人類還能勉力一戰,等進化到AlphaGo Master,零封人類棋手。然而這還沒有完,AlphaGo Zero不再學習人類棋譜,完全通過自學,碾壓AlphaGo Master,對付人類棋手,更如我們捏死一隻螞蟻那麼容易。
所以說,儘管人工智慧創作的第一副作品如同鬼畫桃符,但其潛力無可限量。
那麼,接下來我們會探討如何創作出一幅名畫?No. No.
創作一副畫並不是那麼容易。這幅名為《埃德蒙·貝拉米肖像》的畫作是由巴黎一個名為“顯而易見”(Obvious)的藝術團體創作利用人工智慧技術創作而成,這幅作品是用演算法和15000幅從14世紀到20世紀的肖像畫資料製作而成。
我們還沒有那個條件去創作一副人工智慧的畫作,但我們可以先從基本的著手,生成手寫數字。手寫數字對於機器學習的同學來說,太熟悉不過了。既然是老朋友了,那讓我們開始吧!
首先回顧一下《實戰生成對抗網路[1]:簡介》這篇文章的內容,GAN由生成器和判別器組成。簡單起見,我們選擇簡單的二層神經網路來實現生成器和判別器。
生成器
實現生成器並不難,我們採取的全連線網路拓撲結構為:100 --> 128 --> 784,最後的輸出為784是因為MNIST資料集就是由28 x 28畫素的灰度影象組成。程式碼如下:
G_W1 = tf.Variable(initializer([100, 128]), name='G_W1')
G_b1 = tf.Variable(tf.zeros(shape=[128]), name='G_b1')
G_W2 = tf.Variable(initializer([128, 784]), name='G_W2')
G_b2 = tf.Variable(tf.zeros(shape=[784 ]), name='G_b2')
theta_G = [G_W1, G_W2, G_b1, G_b2]
def generator(z):
G_h1 = tf.nn.relu(tf.matmul(z, G_W1) + G_b1)
G_log_prob = tf.matmul(G_h1, G_W2) + G_b2
G_prob = tf.nn.sigmoid(G_log_prob)
return G_prob
複製程式碼
判別器
判別器正好相反,以MNIST影象作為輸入並返回一個代表真實影象的概率的標量,程式碼如下:
D_W1 = tf.Variable(initializer(shape=[784, 128]), name='D_W1')
D_b1 = tf.Variable(tf.zeros(shape=[128]), name='D_b1')
D_W2 = tf.Variable(initializer(shape=[128, 1]), name='D_W2')
D_b2 = tf.Variable(tf.zeros(shape=[1]), name="D_W2")
theta_D = [D_W1, D_W2, D_b1, D_b2]
def discriminator(x):
D_h1 = tf.nn.relu(tf.matmul(x, D_W1) + D_b1)
D_logit = tf.matmul(D_h1, D_W2) + D_b2
D_prob = tf.nn.sigmoid(D_logit)
return D_prob, D_logit
複製程式碼
訓練演算法
在論文arXiv: 1406.2661, 2014中給出了訓練演算法的虛擬碼:
TensorFlow中的優化器只能做最小化,因為為了最大化損失函式,我們在虛擬碼給出的損失函式前加上一個負號。
D_loss = -tf.reduce_mean(tf.log(D_real) + tf.log(1. - D_fake))
G_loss = -tf.reduce_mean(tf.log(D_fake))
複製程式碼
接下來定義優化器:
# 僅更新D(X)的引數, var_list=theta_D
D_solver = tf.train.AdamOptimizer().minimize(D_loss, var_list=theta_D)
# 僅更新G(X)的引數, var_list=theta_G
G_solver = tf.train.AdamOptimizer().minimize(G_loss, var_list=theta_G)
複製程式碼
最後進行迭代,更新引數:
for it in range(60000):
X_mb, _ = mnist.train.next_batch(mb_size)
_, D_loss_curr = sess.run([D_solver, D_loss], feed_dict={X: X_mb, Z: sample_Z(mb_size, Z_dim)})
_, G_loss_curr = sess.run([G_solver, G_loss], feed_dict={Z: sample_Z(mb_size, Z_dim)})
複製程式碼
整個流程下來,其實和之前的深度學習演算法差不多,非常容易理解。演算法是不是有效果呢?我們可以將迭代過程中生成的手寫數字顯示出來:
嗯,結果雖然有點差強人意,但差不多是手寫數字的字形,而且隨著迭代,越來越接近手寫數字,可以說GAN演算法還是有效的。
小結
一個簡單的GAN網路就這麼幾行程式碼就能搞定,看樣子生成一副畫也沒有什麼難的。先不要這麼樂觀,其實,GAN網路中的坑還是不少,比如在迭代過程中,就出現過如下提示:
Iter: 9000
D loss: nan
G_loss: nan
複製程式碼
從程式碼中我們可以看出,GAN網路依然採用的梯度下降法來迭代求解引數。梯度下降的啟動會選擇一個減小所定義問題損失的方向,但是我們並沒有一個辦法來確保利用GAN網路可以進入納什均衡的狀態,這是一個高維度的非凸優化目標。網路試圖在接下來的步驟中最小化非凸優化目標,最終有可能導致進入振盪而不是收斂到底層正式目標。
另外還有模型坍塌、計數、角度以及全域性結構方面的問題,要解決這些問題,需要使用一些特殊的技巧和方法,後面我們深入各種GAN模型時將會探討。
本文完整的程式碼請參考: github.com/mogoweb/aie…