1. 程式人生 >> TensorFlow(五)——MNIST分類之RNN

TensorFlow(五)——MNIST分類之RNN

import input_data
import tensorflow as tf
import numpy as np

# Load MNIST with one-hot labels (downloads into 'data/' on first run).
mnist = input_data.read_data_sets('data/', one_hot=True)

# Training hyperparameters.
lr = 0.001               # Adam learning rate
training_iters = 100000  # stop once this many examples have been seen
batch_size = 128

# Network size parameters. Each 28x28 image is treated as a sequence of
# 28 time steps, one 28-pixel row per step.
n_inputs = 28        # features per time step (pixels per image row)
n_steps = 28         # sequence length (number of image rows)
n_hidden_units = 128 # LSTM hidden state size
n_classes = 10       # digits 0-9

# Placeholders for a batch of input sequences and one-hot labels.
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# Weights for the input projection (before the LSTM) and the output
# projection (after the LSTM), initialized from a standard normal.
weights = {
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
# Biases initialized to a small positive constant.
biases = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units, ])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes, ]))
}

#定義RNN模型
def RNN(X, weights, biases):
    """Build the LSTM classifier graph.

    Args:
        X: float32 tensor of shape (batch, n_steps, n_inputs).
        weights: dict with 'in' (n_inputs, n_hidden_units) and
            'out' (n_hidden_units, n_classes) weight variables.
        biases: dict with matching 'in' and 'out' bias variables.

    Returns:
        Unnormalized class logits, shape (batch, n_classes).
    """
    # Flatten the time dimension: (batch, steps, inputs) -> (batch*steps, inputs)
    X = tf.reshape(X, [-1, n_inputs])

    # Project every time step into the hidden dimension, then restore the
    # sequence shape expected by dynamic_rnn (batch-major).
    X_in = tf.matmul(X, weights['in']) + biases['in']
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])

    # Single LSTM layer.
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0,
                                             state_is_tuple=True)

    # BUGFIX: derive the batch size from the input tensor rather than the
    # global `batch_size` constant. The original graph only accepted batches
    # of exactly `batch_size` examples, so it could not be evaluated on the
    # test set or on a final partial batch.
    dyn_batch = tf.shape(X_in)[0]
    init_state = lstm_cell.zero_state(dyn_batch, dtype=tf.float32)

    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                             initial_state=init_state,
                                             time_major=False)

    # final_state is an LSTMStateTuple (c, h); classify from the last hidden
    # state h via the output projection.
    results = tf.matmul(final_state[1], weights['out']) + biases['out']
    return results

#定義損失函式和優化器
# Loss and optimizer: softmax cross-entropy over the logits, minimized
# with Adam at learning rate `lr`.
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

# Evaluation: fraction of examples whose argmax prediction matches the
# one-hot label.
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    # Train until roughly `training_iters` examples have been consumed,
    # reporting the accuracy on the current training batch every 20 steps.
    step = 0
    while step * batch_size < training_iters:
        images, labels = mnist.train.next_batch(batch_size)
        # Reshape flat 784-pixel images into (steps, inputs) sequences.
        images = images.reshape([batch_size, n_steps, n_inputs])
        feed = {
            x: images,
            y: labels,
        }
        sess.run([train_op], feed_dict=feed)
        if step % 20 == 0:
            # NOTE: this is accuracy on the training batch just used, not
            # on held-out data.
            print (sess.run(accuracy, feed_dict=feed))
        step += 1

結果:

0.171875
0.671875
0.8046875
0.8203125
0.8203125
0.8671875
0.8515625
0.890625
0.8984375
0.859375
0.921875
0.9375
0.8671875
0.9296875
0.9296875
0.9453125
0.9296875
0.984375
0.9140625
0.9609375
0.96875
0.9765625
0.9609375
0.96875
0.9453125
0.9609375
0.9453125
0.9609375
0.9609375
0.96875
0.953125
0.96875
0.9765625
0.9609375
0.96875
0.953125
0.984375
0.9765625
0.9453125
0.9453125