1. 程式人生 > >Keras實現autoencoder

Keras實現autoencoder

Keras實現autoencoder

Keras使我們搭建神經網路變得異常簡單,之前我們使用了Sequential來搭建LSTM:keras實現LSTM

我們要使用Keras的functional API搭建更加靈活的網路結構,比如說本文的autoencoder,關於autoencoder的介紹可以在這裡找到:deep autoencoder

 

現在我們就開始。

step 0 匯入需要的包

1 import keras
2 from keras.layers import Dense, Input
3 from keras.datasets import mnist
4 from keras.models import Model
5 import numpy as np

 step 1 資料預處理

這裡需要說明一下,匯入的原始資料shape為(60000,28,28),autoencoder使用(60000,28*28),而且autoencoder屬於無監督學習,所以只需要匯入x_train和x_test.

複製程式碼

1 (x_train, _), (x_test, _) = mnist.load_data()
2 x_train = x_train.astype('float32')/255.0
3 x_test = x_test.astype('float32')/255.0
4 #print(x_train.shape)
5 x_train = x_train.reshape(x_train.shape[0], -1)
6 x_test = x_test.reshape(x_test.shape[0], -1)
7 #print(x_train.shape)

複製程式碼

step 2 向圖片新增噪聲

新增噪聲是為了讓autoencoder更robust,不容易出現過擬合。

複製程式碼

1 #add random noise
2 x_train_nosiy = x_train + 0.3 * np.random.normal(loc=0., scale=1., size=x_train.shape)
3 x_test_nosiy = x_test + 0.3 * np.random.normal(loc=0, scale=1, size=x_test.shape)
4 x_train_nosiy = np.clip(x_train_nosiy, 0., 1.)
5 x_test_nosiy = np.clip(x_test_nosiy, 0, 1.)
6 print(x_train_nosiy.shape, x_test_nosiy.shape)

複製程式碼

step 3 搭建網路結構

分別構建encoded和decoded,然後將它們連結起來構成整個autoencoder。使用Model建模。

複製程式碼

1 #build autoencoder model
2 input_img = Input(shape=(28*28,))
3 encoded = Dense(500, activation='relu')(input_img)
4 decoded = Dense(784, activation='sigmoid')(encoded)
5 
6 autoencoder = Model(input=input_img, output=decoded)

複製程式碼

 step 4 compile

因為這裡是讓解壓後的圖片和原圖片做比較, loss使用的是binary_crossentropy。

1 autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
2 autoencoder.summary()

 

step 5 train

指定epochs,batch_size,可以使用validation_data,keras訓練的時候不會使用它,而是用來做模型評價。

autoencoder.fit(x_train_nosiy, x_train, epochs=20, batch_size=128, verbose=1, validation_data=(x_test, x_test))

 

step 6 對比一下解壓縮後的圖片和原圖片

複製程式碼

 1 %matplotlib inline
 2 import matplotlib.pyplot as plt
 3 
 4 #decoded test images
 5 decoded_img = autoencoder.predict(x_test_nosiy)
 6 
 7 n = 10
 8 plt.figure(figsize=(20, 4))
 9 for i in range(n):
10     #noisy data
11     ax = plt.subplot(3, n, i+1)
12     plt.imshow(x_test_nosiy[i].reshape(28, 28))
13     plt.gray()
14     ax.get_xaxis().set_visible(False)
15     ax.get_yaxis().set_visible(False)
16     #predict
17     ax = plt.subplot(3, n, i+1+n)
18     plt.imshow(decoded_img[i].reshape(28, 28))
19     plt.gray()
20     ax.get_yaxis().set_visible(False)
21     ax.get_xaxis().set_visible(False)
22     #original
23     ax = plt.subplot(3, n, i+1+2*n)
24     plt.imshow(x_test[i].reshape(28, 28))
25     plt.gray()
26     ax.get_yaxis().set_visible(False)
27     ax.get_xaxis().set_visible(False)
28 plt.show()

複製程式碼

 這樣的結果,你能分出哪個是壓縮解壓縮後的圖片哪個是原圖片嗎?

reference:

https://keras.io/getting-started/functional-api-guide/