Implementing a Variational Autoencoder
阿新 • Published: 2018-12-17
1. A VAE is somewhat similar to a GAN: both can generate new sample data from some input. The difference is that the VAE assumes the latent codings follow a normal distribution, whereas the GAN makes no such assumption and learns its structure entirely from the training data.
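Concretely, this normality assumption is what makes the VAE's extra loss term tractable. If the encoder outputs a mean \mu and \gamma = \log \sigma^2 for each latent dimension, the KL divergence between the learned coding distribution and the standard normal prior has the closed form

D_{KL}\left(\mathcal{N}(\mu, \sigma^2) \,\|\, \mathcal{N}(0, 1)\right) = \frac{1}{2} \sum_i \left( e^{\gamma_i} + \mu_i^2 - 1 - \gamma_i \right), \qquad \gamma_i = \log \sigma_i^2

which is exactly the latent_loss term computed in the code below.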
Below is the code for the variational autoencoder:
import numpy as np
import tensorflow as tf
import tensorflow.contrib as contrib
from tensorflow.contrib.layers import fully_connected
from tensorflow.examples.tutorials.mnist import input_data
import functiontool  # the author's helper module containing the plotting function below

# Define global hyperparameters
n_inputs = 28 * 28
n_hidden1 = 500
n_hidden2 = 500
n_hiddenmiddle = 30  # size of the latent (coding) layer
n_hidden3 = n_hidden2
n_hidden4 = n_hidden1
n_outputs = n_inputs
learning_rate = 0.001

mnist_data = input_data.read_data_sets("MNIST_data/")

# Define the network structure: encoder -> latent sampling -> decoder
with contrib.framework.arg_scope(
        [fully_connected],
        activation_fn=tf.nn.elu,
        weights_initializer=contrib.layers.variance_scaling_initializer()):
    X = tf.placeholder(dtype=tf.float32, shape=[None, n_inputs])
    hidden1 = fully_connected(X, n_hidden1)
    hidden2 = fully_connected(hidden1, n_hidden2)
    # The encoder outputs a mean and gamma = log(sigma^2) for each latent dimension
    hidden_middle_mean = fully_connected(hidden2, n_hiddenmiddle, activation_fn=None)
    hidden_middle_gamma = fully_connected(hidden2, n_hiddenmiddle, activation_fn=None)
    hidden_middle_sigma = tf.exp(0.5 * hidden_middle_gamma)
    # Reparameterization trick: sample the codings as mean + sigma * noise,
    # so gradients can flow through the sampling step
    noise = tf.random_normal(tf.shape(hidden_middle_sigma))
    hidden_middle = hidden_middle_mean + hidden_middle_sigma * noise
    hidden3 = fully_connected(hidden_middle, n_hidden3)
    hidden4 = fully_connected(hidden3, n_hidden4)
    logits = fully_connected(hidden4, n_outputs, activation_fn=None)
    outputs = tf.sigmoid(logits)

# Define the loss: reconstruction cross-entropy plus the KL divergence
# between the coding distribution and the standard normal prior
reconstruction_loss = tf.reduce_sum(
    tf.nn.sigmoid_cross_entropy_with_logits(labels=X, logits=logits))
latent_loss = 0.5 * tf.reduce_sum(
    tf.exp(hidden_middle_gamma) + tf.square(hidden_middle_mean)
    - 1 - hidden_middle_gamma)
sum_loss = reconstruction_loss + latent_loss
optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate)
train_optimizer = optimizer.minimize(sum_loss)

init = tf.global_variables_initializer()
saver = tf.train.Saver()

# Train the network
n_epochs = 60
n_batch = 150
with tf.Session() as session:
    init.run()
    for epoch in range(n_epochs):
        batch_nums = mnist_data.train.num_examples // n_batch
        for batch_index in range(batch_nums):
            print("\r{}%".format(100 * batch_index // batch_nums), end="")
            X_train, _ = mnist_data.train.next_batch(n_batch)  # labels unused
            session.run(train_optimizer, feed_dict={X: X_train})
        loss_val = sum_loss.eval(feed_dict={X: X_train})
        print("\rTrain loss: {}".format(loss_val))
        saver.save(session, "weight/VaAuto.ckpt")
    # Generate new digits by feeding random codings straight into the decoder
    test_rng = np.random.normal(size=(10, n_hiddenmiddle))
    out_val = outputs.eval(feed_dict={hidden_middle: test_rng})
    functiontool.show_reconstructed_digits_old(out_val)
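Once training finishes, the saved checkpoint can be reloaded in a fresh session to generate digits without retraining. Here is a minimal sketch, assuming the graph-construction code above has already run in the current process (so outputs, hidden_middle, saver, and n_hiddenmiddle refer to the same objects):

# Sketch: restore the trained weights and sample new digits from the decoder.
# Assumes the graph above has already been built in this process.
import numpy as np
import tensorflow as tf

with tf.Session() as session:
    saver.restore(session, "weight/VaAuto.ckpt")           # load the trained weights
    codings = np.random.normal(size=(10, n_hiddenmiddle))  # sample from the N(0, 1) prior
    generated = outputs.eval(feed_dict={hidden_middle: codings})
    functiontool.show_reconstructed_digits_old(generated)

Because the latent loss pushes the codings toward a standard normal distribution during training, sampling codings from N(0, 1) at generation time produces plausible digits.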
The plotting function it uses is:
import matplotlib.pyplot as plt

def plot_image(image):
    # Minimal helper (assumed): reshape the flat 784-vector to 28x28 and draw it
    plt.imshow(image.reshape(28, 28), cmap="Greys", interpolation="nearest")
    plt.axis("off")

def show_reconstructed_digits_old(outputs):
    n_digits = outputs.shape[0]
    plt.figure(figsize=(8, 50))
    for i in range(n_digits):
        plt.subplot(n_digits, 1, i + 1)  # one generated digit per row
        plot_image(outputs[i])
    plt.show()
The resulting training output is: