1. 程式人生 > >簡單易懂的自動編碼器

簡單易懂的自動編碼器

作者:葉虎

編輯:田旭

引言

自動編碼器是一種無監督的神經網路模型,它可以學習到輸入資料的隱含特徵,這稱為編碼(coding),同時用學習到的新特徵可以重構出原始輸入資料,稱之為解碼(decoding)。從直觀上來看,自動編碼器可以用於特徵降維,類似主成分分析PCA,但是其相比PCA其效能更強,這是由於神經網路模型可以提取更有效的新特徵。除了進行特徵降維,自動編碼器學習到的新特徵可以送入有監督學習模型中,所以自動編碼器可以起到特徵提取器的作用。作為無監督學習模型,自動編碼器還可以用於生成與訓練樣本不同的新資料,這樣自動編碼器(變分自動編碼器,VariationalAutoencoders)就是生成式模型。


本文將會講述自動編碼器的基本原理以及常用的自動編碼器模型:堆疊自動編碼器(StackedAutoencoder)。後序的文章會講解自動編碼器其他模型:去噪自動編碼器(DenoisingAutoencoder),稀疏自動編碼器(SparseAutoencoder)以及變分自動編碼器。所有的模型都會使用Tensorflow進行程式設計實現。

自動編碼器原理

自動編碼器的基本結構如圖1所示,包括編碼和解碼兩個過程:

640?wx_fmt=png&wxfrom=5&wx_lazy=1

圖1自動編碼器的編碼與解碼

自動編碼器是將輸入0?wx_fmt=png進行編碼,得到新的特徵0?wx_fmt=png,並且希望原始的輸入0?wx_fmt=png能夠從新的特徵0?wx_fmt=png重構出來。編碼過程如下:

0?wx_fmt=png

可以看到,和神經網路結構一樣,其編碼就是線性組合之後加上非線性的啟用函式。如果沒有非線性的包裝,那麼自動編碼器就和普通的PCA沒有本質區別了。利用新的特徵0?wx_fmt=png

,可以對輸入0?wx_fmt=png重構,即解碼過程:

0?wx_fmt=png

我們希望重構出的0?wx_fmt=png0?wx_fmt=png儘可能一致,可以採用最小化負對數似然的損失函式來訓練這個模型:

0?wx_fmt=png

對於高斯分佈的資料,採用均方誤差就好,而對於伯努利分佈可以採用交叉熵,這個是可以根據似然函式推匯出來的。一般情況下,我們會對自動編碼器加上一些限制,常用的是使0?wx_fmt=png,這稱為繫結權重(tiedweights),本文所有的自動編碼器都加上這個限制。有時候,我們還會給自動編碼器加上更多的約束條件,去噪自動編碼器以及稀疏自動編碼器就屬於這種情況,因為大部分時候單純地重構原始輸入並沒有什麼意義,我們希望自動編碼器在近似重構原始輸入的情況下能夠捕捉到原始輸入更有價值的資訊。

堆疊自動編碼器

前面我們講了自動編碼器的原理,不過所展示的自動編碼器只是簡答的含有一層,其實可以採用更深層的架構,這就是堆疊自動編碼器或者深度自動編碼器,本質上就是增加中間特徵層數。這裡我們以MNIST資料為例來說明自動編碼器,建立兩個隱含層的自動編碼器,如圖2所示:

0?wx_fmt=png

圖2堆疊自動編碼器架構

對於MNIST來說,其輸入是28*28=784維度的特徵,這裡使用了兩個隱含層其維度分別為300和150,可以看到是不斷降低特徵的維度了。得到的最終編碼為150維度的特徵,使用這個特徵進行反向重構得到重建的特徵,我們希望重建特徵和原始特徵儘量相同。由於MNIST是0,1量,可以採用交叉熵作為損失函式,TF的程式碼核心程式碼如下:

(左右滑動,檢視完整程式碼)

n_inputs = 28*28
n_hidden1 = 300
n_hidden2 = 150

# 定義輸入佔位符:不需要y
X = tf.placeholder(tf.float32, [None, n_inputs])

# 定義訓練引數
initializer = tf.contrib.layers.variance_scaling_initializer()
W1 = tf.Variable(initializer([n_inputs, n_hidden1]), name="W1")
b1 = tf.Variable(tf.zeros([n_hidden1,]), name="b1")
W2 = tf.Variable(initializer([n_hidden1, n_hidden2]), name="W2")
b2 = tf.Variable(tf.zeros([n_hidden2,]), name="b2")
W3 = tf.transpose(W2, name="W3")
b3 = tf.Variable(tf.zeros([n_hidden1,]), name="b3")
W4 = tf.transpose(W1, name="W4")
b4 = tf.Variable(tf.zeros([n_inputs,]), name="b4")

# 構建模型
h1 = tf.nn.sigmoid(tf.nn.xw_plus_b(X, W1, b1))
h2 = tf.nn.sigmoid(tf.nn.xw_plus_b(h1, W2, b2))
h3 = tf.nn.sigmoid(tf.nn.xw_plus_b(h2, W3, b3))
outputs = tf.nn.sigmoid(tf.nn.xw_plus_b(h3, W4, b4))

# 定義loss
loss = -tf.reduce_mean(tf.reduce_sum(X * tf.log(outputs) +
                   (1 - X) * tf.log(1 - outputs), axis=1))
train_op = tf.train.AdamOptimizer(1e-02).minimize(loss)

當訓練這個模型後,你可以將原始MNIST的數字手寫體與重構出的手寫體做個比較,如圖3所示,上面是原始圖片,而下面是重構圖片,基本上沒有差別了。儘管我們將維度從784降低到了150,得到的新特徵還是抓取了原始特徵的核心資訊。

0?wx_fmt=png

圖3原始圖片(上)與重構圖片對比(下)

有一點,上面的訓練過程是一下子訓練完成的,其實對於堆疊編碼器來說,有時候會採用逐層訓練方式。直白點就是一層一層地訓練:先訓練X->h1的編碼,使h1->X的重構誤差最小化;之後再訓練h1->h2的編碼,使h2->h1的重構誤差最小化。其實現程式碼如下:

(左右滑動,檢視完整程式碼)

# 構建模型
h1 = tf.nn.sigmoid(tf.nn.xw_plus_b(X, W1, b1))
h1_recon = tf.nn.sigmoid(tf.nn.xw_plus_b(h1, W4, b4))
h2 = tf.nn.sigmoid(tf.nn.xw_plus_b(h1, W2, b2))
h2_recon = tf.nn.sigmoid(tf.nn.xw_plus_b(h2, W3, b3))
outputs = tf.nn.sigmoid(tf.nn.xw_plus_b(h2_recon, W4, b4))

learning_rate = 1e-02
# X->h1
with tf.name_scope("layer1"):
   layer1_loss = -tf.reduce_mean(tf.reduce_sum(X * tf.log(h1_recon) +
                           (1-X) * tf.log(1-h1_recon), axis=1))
   layer1_train_op = tf.train.AdamOptimizer(learning_rate).minimize(layer1_loss,
                                           var_list=[W1, b1, b4])

# h1->h2
with tf.name_scope("layer2"):
   layer2_loss = -tf.reduce_mean(tf.reduce_sum(h1 * tf.log(h2_recon) +
                 (1 - h1) * tf.log(1 - h2_recon), axis=1))
   layer2_train_op = tf.train.AdamOptimizer(learning_rate).minimize(layer2_loss,
                     var_list=[W2, b2, b3])

分層訓練之後,最終得到如圖4所示的對比結果,結果還是不錯的。

0?wx_fmt=png

圖4原始圖片(上)與重構圖片對比(下)

小結

自動編碼器應該是最通俗易懂的無監督神經網路模型,這裡我們介紹了其基本原理及堆疊自動編碼器。後序會介紹更多的自動編碼器模型。

參考文獻

1. Hands-On Machine Learning withScikit-Learn and TensorFlow, Aurélien Géron, 2017.

2. Deep Learning Tutorials:AutoEncoders, Denoising Autoencoders.

http://deeplearning.net/tutorial/dA.html#daa

3. Learning deep architectures for AI, Foundations and Trends inMachine Learning, Y. Bengio, 2009.

掃描個人微訊號,

拉你進機器學習大牛群。

福利滿滿,名額已不多…

0?wx_fmt=jpeg

80%的AI從業者已關注我們微信公眾號

0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif

0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif 0?wx_fmt=gif