1. 程式人生 > >單向RNN和雙向RNN在mnist資料集上的分類實驗

單向RNN和雙向RNN在mnist資料集上的分類實驗

用RNN做影像分類的思路很特別:把每張28×28的MNIST影像視為長度為28的序列,每個時間步輸入一行28個像素,最後用最後一個時間步的輸出進行分類。具體原理可以參考相關論文。以下是單向RNN(rnn)和雙向RNN(birnn)的實驗程式碼:

#!/usr/bin/env python
# -*- coding: utf-8 -*-
# created by fhqplzj on 2017/06/19 下午10:28
from __future__ import print_function

import tensorflow as tf
from tensorflow.contrib import rnn
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST with one-hot encoded labels from the author's local data directory.
mnist = input_data.read_data_sets(
    '/Users/fhqplzj/github/TensorFlow-Examples/examples/3_NeuralNetworks/data',
    one_hot=True,
)

# Optimisation settings.
learning_rate = 0.001
training_iters = 100000  # training stops once step * batch_size reaches this value
batch_size = 128
display_step = 10  # metrics are printed every `display_step` training steps

# Model dimensions: every sample is fed as n_steps time steps of n_input values
# (the training loop reshapes each batch to (-1, n_steps, n_input)).
n_input = 28
n_steps = 28
n_hidden = 128  # size of each LSTM cell
n_classes = 10  # width of the one-hot label vector

# Graph inputs: the image sequence and its one-hot label.
x = tf.placeholder(tf.float32, [None, n_steps, n_input])
y = tf.placeholder(tf.float32, [None, n_classes])

# Output projections. 'out1' matches the single-direction model (n_hidden
# features); 'out2' matches the bidirectional model (2 * n_hidden features).
weights = dict(
    out1=tf.Variable(tf.random_normal([n_hidden, n_classes])),
    out2=tf.Variable(tf.random_normal([2 * n_hidden, n_classes])),
)
# A single output bias, shared by both models.
biases = dict(out=tf.Variable(tf.random_normal([n_classes])))


def RNN(x, weights, biases):
    """Build a single-direction LSTM classifier and return its logits.

    Args:
        x: input tensor that is split into ``n_steps`` slices along axis 1
            (the module-level placeholder has shape
            ``[None, n_steps, n_input]``).
        weights: mapping whose ``'out1'`` entry projects the LSTM output
            onto the classes.
        biases: mapping whose ``'out'`` entry is added to the projection.

    Returns:
        The un-normalised class scores computed from the output of the
        last time step.
    """
    # static_rnn expects a Python list with one tensor per time step.
    step_inputs = tf.unstack(x, num=n_steps, axis=1)
    cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    step_outputs, _final_state = rnn.static_rnn(cell, step_inputs, dtype=tf.float32)
    # Only the output of the final time step is used for classification.
    last_output = step_outputs[-1]
    projected = tf.matmul(last_output, weights['out1'])
    return projected + biases['out']


def BiRNN(x, weights, biases):
    """Build a bidirectional LSTM classifier and return its logits.

    Args:
        x: input tensor that is split into ``n_steps`` slices along axis 1
            (the module-level placeholder has shape
            ``[None, n_steps, n_input]``).
        weights: mapping whose ``'out2'`` entry projects the combined
            forward/backward output onto the classes.
        biases: mapping whose ``'out'`` entry is added to the projection.

    Returns:
        The un-normalised class scores computed from the last element of
        the bidirectional output list.
    """
    # static_bidirectional_rnn expects a Python list with one tensor per time step.
    step_inputs = tf.unstack(x, num=n_steps, axis=1)
    forward_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    backward_cell = rnn.BasicLSTMCell(n_hidden, forget_bias=1.0)
    # The two final states are returned as well but are not needed here.
    step_outputs, _fw_state, _bw_state = rnn.static_bidirectional_rnn(
        forward_cell, backward_cell, step_inputs, dtype=tf.float32)
    # Per the TF docs each output is the depth-concatenation of the forward
    # and backward outputs, hence the 2 * n_hidden wide 'out2' projection.
    last_output = step_outputs[-1]
    projected = tf.matmul(last_output, weights['out2'])
    return projected + biases['out']


# Train and evaluate each model in turn. Both models are added to the same
# default graph and read from the shared `x` / `y` placeholders.
for func in (RNN, BiRNN):
    # `__name__` is available on both Python 2 and Python 3, whereas the
    # `func_name` attribute only exists on Python 2.
    print(func.__name__.center(100, '+'))

    # Build the model, its loss, the training op and the accuracy metric.
    pred = func(x, weights, biases)
    cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
    optimizer = tf.train.AdamOptimizer(learning_rate=learning_rate).minimize(cost)
    correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
    accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

    # Created after the model so that it also covers the variables added above.
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        step = 1
        # Keep going until roughly `training_iters` samples have been consumed.
        while step * batch_size < training_iters:
            batch_x, batch_y = mnist.train.next_batch(batch_size)
            # Turn each flat image vector into n_steps rows of n_input values.
            batch_x = batch_x.reshape((-1, n_steps, n_input))
            batch_feed = {x: batch_x, y: batch_y}
            sess.run(optimizer, feed_dict=batch_feed)
            if step % display_step == 0:
                # Fetch both metrics in one call so the forward pass over the
                # batch runs once instead of twice.
                acc, loss = sess.run([accuracy, cost], feed_dict=batch_feed)
                print('acc={:.6f},cost={:.6f}'.format(acc, loss))
            step += 1
        print('Optimization Finished!')

        # Report accuracy on a single batch of `total_len` test images
        # (not on the full test set).
        total_len = 128
        test_x, test_y = mnist.test.next_batch(total_len)
        test_x = test_x.reshape((-1, n_steps, n_input))
        print(sess.run(accuracy, feed_dict={
            x: test_x,
            y: test_y
        }))