TensorFlow (5): Classifying MNIST with an RNN
阿新 · Published 2018-12-09
The full TF 1.x program below trains a single-layer LSTM on MNIST: each 28x28 image is read as a sequence of 28 rows (time steps) of 28 pixels, and the final hidden state is projected onto the 10 digit classes.

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf

mnist = input_data.read_data_sets('data/', one_hot=True)

# Training hyperparameters
lr = 0.001
training_iters = 100000
batch_size = 128

# Network parameters
n_inputs = 28         # input size per time step (pixels per image row)
n_steps = 28          # number of time steps (rows per image)
n_hidden_units = 128  # units in the LSTM hidden layer
n_classes = 10        # digit classes 0-9

# Placeholders for the input data
x = tf.placeholder(tf.float32, [None, n_steps, n_inputs])
y = tf.placeholder(tf.float32, [None, n_classes])

# Weights and biases
weights = {
    'in': tf.Variable(tf.random_normal([n_inputs, n_hidden_units])),
    'out': tf.Variable(tf.random_normal([n_hidden_units, n_classes]))
}
biases = {
    'in': tf.Variable(tf.constant(0.1, shape=[n_hidden_units])),
    'out': tf.Variable(tf.constant(0.1, shape=[n_classes]))
}

# Define the RNN model
def RNN(X, weights, biases):
    # Flatten the input X ==> (128 batch * 28 steps, 28 inputs)
    X = tf.reshape(X, [-1, n_inputs])
    # Project each row into the hidden dimension, then restore the time axis
    X_in = tf.matmul(X, weights['in']) + biases['in']
    X_in = tf.reshape(X_in, [-1, n_steps, n_hidden_units])
    # Single-layer LSTM
    lstm_cell = tf.contrib.rnn.BasicLSTMCell(n_hidden_units, forget_bias=1.0,
                                             state_is_tuple=True)
    init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
    outputs, final_state = tf.nn.dynamic_rnn(lstm_cell, X_in,
                                             initial_state=init_state,
                                             time_major=False)
    # final_state is an LSTMStateTuple (c, h); classify from the hidden state h
    results = tf.matmul(final_state[1], weights['out']) + biases['out']
    return results

# Loss function and optimizer
pred = RNN(x, weights, biases)
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=pred, labels=y))
train_op = tf.train.AdamOptimizer(lr).minimize(cost)

# Prediction and evaluation ops
correct_pred = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    step = 0
    while step * batch_size < training_iters:
        batch_xs, batch_ys = mnist.train.next_batch(batch_size)
        batch_xs = batch_xs.reshape([batch_size, n_steps, n_inputs])
        sess.run([train_op], feed_dict={x: batch_xs, y: batch_ys})
        if step % 20 == 0:
            print(sess.run(accuracy, feed_dict={x: batch_xs, y: batch_ys}))
        step += 1
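Note that the accuracy printed during training is measured only on the current 128-image training batch. Because init_state is built with a fixed batch_size, any evaluation feed must also contain exactly 128 examples. A minimal sketch of test-set evaluation under that constraint, to be appended inside the session block after the training loop (this addition is not part of the original post):

    # Hedged sketch: average accuracy over the MNIST test set in chunks of
    # batch_size, since the graph's initial LSTM state is fixed to 128 examples.
    # Runs inside `with tf.Session() as sess:` after the training loop.
    n_test_batches = mnist.test.num_examples // batch_size  # remainder dropped
    test_acc = 0.0
    for _ in range(n_test_batches):
        tx, ty = mnist.test.next_batch(batch_size)
        tx = tx.reshape([batch_size, n_steps, n_inputs])
        test_acc += sess.run(accuracy, feed_dict={x: tx, y: ty})
    print('test accuracy: %.4f' % (test_acc / n_test_batches))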
Result (accuracy on the current training batch, printed every 20 steps):
0.171875 0.671875 0.8046875 0.8203125 0.8203125 0.8671875 0.8515625 0.890625
0.8984375 0.859375 0.921875 0.9375 0.8671875 0.9296875 0.9296875 0.9453125
0.9296875 0.984375 0.9140625 0.9609375 0.96875 0.9765625 0.9609375 0.96875
0.9453125 0.9609375 0.9453125 0.9609375 0.9609375 0.96875 0.953125 0.96875
0.9765625 0.9609375 0.96875 0.953125 0.984375 0.9765625 0.9453125 0.9453125
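The curve is typical: batch accuracy climbs from near-random (about 0.17) into the mid-0.9 range within a few hundred steps. For readers on TensorFlow 2.x, where tf.contrib, placeholders, and sessions no longer exist, roughly the same row-by-row LSTM classifier can be sketched with tf.keras. This is an assumed modern translation, not the original post's code, and it feeds image rows directly into the LSTM instead of going through the separate 'in' projection:

    import tensorflow as tf  # assumes TensorFlow 2.x

    # Each 28x28 image is treated as 28 time steps of 28 features
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.mnist.load_data()
    x_train = x_train.astype('float32') / 255.0  # shape (60000, 28, 28)
    x_test = x_test.astype('float32') / 255.0

    model = tf.keras.Sequential([
        tf.keras.layers.LSTM(128, input_shape=(28, 28)),  # 28 steps x 28 inputs
        tf.keras.layers.Dense(10, activation='softmax'),
    ])
    model.compile(optimizer=tf.keras.optimizers.Adam(0.001),
                  loss='sparse_categorical_crossentropy',  # integer labels
                  metrics=['accuracy'])
    model.fit(x_train, y_train, batch_size=128, epochs=2,
              validation_data=(x_test, y_test))

Unlike the graph version above, Keras handles the initial LSTM state automatically, so evaluation is not tied to a fixed batch size.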