Classification Experiments with Unidirectional and Bidirectional RNNs on the MNIST Dataset
Using an RNN for image classification is an unusual idea, and it is not immediately obvious why it should work; the related papers cover the reasoning. Below is an experiment comparing a plain RNN and a bidirectional RNN (BiRNN) on MNIST, where each image is read as a sequence of rows. The short sketch that follows makes the reshaping concrete.
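A minimal numpy sketch of the row-as-timestep view (the zero array here is a hypothetical stand-in for a real MNIST batch):

import numpy as np

# Dummy stand-in for a flattened MNIST batch: mnist.train.next_batch
# returns arrays of shape (batch_size, 784).
batch_x = np.zeros((128, 784), dtype=np.float32)

# Read each image row by row: one 28x28 image becomes a sequence of
# 28 timesteps, each a 28-dimensional input vector.
batch_x = batch_x.reshape((-1, 28, 28))  # (batch_size, n_steps, n_input)
print(batch_x.shape)                     # (128, 28, 28)

The full experiment script: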
#!/usr/bin/env python
# -*- coding: utf-8 -*-
# created by fhqplzj on 2017/06/19 at 10:28 PM
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

mnist = input_data.read_data_sets('/Users/fhqplzj/github/TensorFlow-Examples/examples/3_NeuralNetworks/data',
                                  one_hot=True)

learning_rate = 0.001
training_iters = 100000
batch_size = 128
display_step = 10

n_input = 28    # each 28-pixel row of an image is one input vector
n_steps = 28    # 28 rows -> 28 timesteps
n_hidden = 128  # hidden units per LSTM cell
n_classes = 10  # digits 0-9

x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

weights = {
    'out1': tf.Variable(tf.random_normal([n_hidden, n_classes])),     # for the unidirectional RNN
    'out2': tf.Variable(tf.random_normal([2 * n_hidden, n_classes]))  # for the BiRNN (fw + bw concatenated)
}
biases = {
    'out': tf.Variable(tf.random_normal([n_classes]))
}


def RNN(x, weights, biases):
    # unstack (batch, n_steps, n_input) into a list of n_steps tensors of shape (batch, n_input)
    x = tf.unstack(x, n_steps, 1)
    lstm_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    outputs, states = rnn.static_rnn(lstm_cell, x, dtype=tf.float32)
    # classify using the output at the last timestep
    return tf.matmul(outputs[-1], weights['out1']) + biases['out']


def BiRNN(x, weights, biases):
    x = tf.unstack(x, n_steps, 1)
    lstm_fw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    lstm_bw_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # each output is the concatenation of forward and backward states, hence 2 * n_hidden wide
    outputs, _, _ = rnn.static_bidirectional_rnn(lstm_fw_cell, lstm_bw_cell, x, dtype=tf.float32)
    return tf.matmul(outputs[-1], weights['out2']) + biases['out']


for func in (RNN, BiRNN):
    # func.__name__ works on both Python 2 and 3 (func.func_name is Python 2 only)
    print(func.__name__.center(100, '+'))
    pred = func(x, weights, biases)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        step = 1
        while step * batch_size < training_iters:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            batch_x = batch_x.reshape((-1, n_steps, n_input))
            sess.run(optimizer, feed_dict={x: batch_x, y: batch_y})
            if step % display_step == 0:
                acc = sess.run(accuracy, feed_dict={x: batch_x, y: batch_y})
                loss = sess.run(cost, feed_dict={x: batch_x, y: batch_y})
                print('acc={:.6f},cost={:.6f}'.format(acc, loss))
            step += 1
        print('Optimization Finished!')
        total_len = 128
        test_x, test_y = mnist.test.next_batch(total_len)
        test_x = test_x.reshape((-1, n_steps, n_input))
        print(sess.run(accuracy, feed_dict={x: test_x, y: test_y}))
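The script above depends on tf.contrib and the placeholder/Session API, both removed in TensorFlow 2.x. For readers on a current TensorFlow, here is a minimal sketch of the same RNN-vs-BiRNN comparison using tf.keras; this is a port under that assumption, not the original author's code. The hyperparameters (128 hidden units, batch size 128, Adam at 1e-3) follow the script above, while the single training epoch is an arbitrary choice for illustration.

import tensorflow as tf

# MNIST as (28, 28) images; rows serve as timesteps, as in the original script.
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0


def build_model(bidirectional):
    lstm = tf.keras.layers.LSTM(128)  # returns the last timestep's output by default
    if bidirectional:
        # Bidirectional concatenates forward and backward states by default,
        # so the classifier sees a 256-dimensional vector.
        lstm = tf.keras.layers.Bidirectional(lstm)
    return tf.keras.Sequential([
        tf.keras.Input(shape=(28, 28)),
        lstm,
        tf.keras.layers.Dense(10, activation='softmax'),
    ])


for bidirectional in (False, True):
    model = build_model(bidirectional)
    model.compile(optimizer=tf.keras.optimizers.Adam(1e-3),
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=128, epochs=1, verbose=2)
    model.evaluate(x_test, y_test, verbose=2)

The default concatenation in the Bidirectional wrapper mirrors the original script, where static_bidirectional_rnn outputs are 2 * n_hidden wide and are multiplied by weights['out2'].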