TensorFlow實戰之LSTM
阿新 • 發佈:2018-12-31
# Train a single-layer LSTM classifier on MNIST (TensorFlow 1.x graph mode).
#
# Each flattened 28x28 image is read as a sequence of 28 rows (time steps) of
# 28 pixels. Every row is projected by a ReLU dense layer, the projected
# sequence runs through an LSTM, and the LSTM output at the last time step is
# mapped to 10 class probabilities by a softmax layer.
import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

# MNIST with one-hot labels, read via read_data_sets from 'MNIST_data'.
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

# Defaults; setnumber() may overwrite them from user input.
train_num = 1000  # number of training steps (mini-batches)
batch_num = 150   # mini-batch size


def _read_int(prompt, default):
    """Prompt for a positive integer; an empty answer keeps *default*.

    Raises:
        ValueError: if the answer is not an integer or is not positive.
    """
    answer = input(prompt).strip()
    if not answer:
        return default
    value = int(answer)
    if value <= 0:
        raise ValueError('expected a positive integer, got %d' % value)
    return value


def setnumber():
    """Ask the user for the number of training steps and the batch size.

    The values are stored in the module-level ``train_num`` and ``batch_num``
    (declared ``global`` so ``Net`` and the training loop see them) and are
    also returned as a ``(train_num, batch_num)`` tuple.
    """
    global train_num, batch_num
    train_num = _read_int('請輸入訓練次數:', train_num)
    batch_num = _read_int('請輸入批次:', batch_num)
    return train_num, batch_num


class Net:
    """Row-by-row LSTM classifier for flattened 28x28 MNIST images."""

    def __init__(self, hidden_num=None):
        """Create the input placeholders and the dense-layer weights.

        Args:
            hidden_num: width of the row projection and of the LSTM state.
                Defaults to ``batch_num + 28`` (read when ``Net`` is built),
                which is the size the original model used.
        """
        if hidden_num is None:
            # The original tied the hidden width to the batch size; keep that
            # as the default so the model size is unchanged.
            hidden_num = batch_num + 28
        self.hidden_num = hidden_num
        # Flattened images [batch, 784] and one-hot labels [batch, 10].
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
        # Per-row input projection: 28 pixels -> hidden_num features.
        self.in_w = tf.Variable(
            tf.truncated_normal([28, hidden_num], dtype=tf.float32, stddev=0.1))
        self.in_b = tf.Variable(tf.zeros([hidden_num]))
        # Output layer: last LSTM output -> 10 class scores.
        self.out_w = tf.Variable(
            tf.truncated_normal([hidden_num, 10], dtype=tf.float32, stddev=0.1))
        self.out_b = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph; sets ``self.output`` (softmax probs)."""
        # [batch, 784] -> [batch*28, 28]: every image row becomes one sample.
        self.y1 = tf.reshape(self.x, [-1, 28])
        # [batch*28, 28] -> [batch*28, hidden_num]
        self.y2 = tf.nn.relu(tf.matmul(self.y1, self.in_w) + self.in_b)
        # Back to a sequence: [batch, 28 time steps, hidden_num].
        self.y3 = tf.reshape(self.y2, [-1, 28, self.hidden_num])
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(self.hidden_num)
        # Size the zero state from the runtime batch instead of the fixed
        # batch_num, so the graph accepts batches of any size.
        init_state = lstm_cell.zero_state(tf.shape(self.y3)[0], dtype=tf.float32)
        outputs, _ = tf.nn.dynamic_rnn(
            lstm_cell, self.y3, initial_state=init_state, time_major=False)
        # Keep only the output of the last time step: [batch, hidden_num].
        self.y4 = outputs[:, -1, :]
        self.output = tf.nn.softmax(tf.matmul(self.y4, self.out_w) + self.out_b)

    def backward(self):
        """Build the loss and the Adam training op; call after forward()."""
        # Mean squared error between softmax probabilities and one-hot labels.
        # NOTE(review): softmax cross-entropy is the more usual loss here.
        self.loss = tf.reduce_mean((self.output - self.y) ** 2)
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)


def _accuracy(labels, probs):
    """Fraction of rows whose argmax in *probs* matches that in *labels*."""
    return float(np.mean(np.argmax(labels, axis=1) == np.argmax(probs, axis=1)))


if __name__ == '__main__':
    setnumber()
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for i in range(train_num):
            xs, ys = mnist.train.next_batch(batch_num)
            loss, _ = sess.run([net.loss, net.opt],
                               feed_dict={net.x: xs, net.y: ys})
            if i % 100 == 0:
                # Evaluate on one batch of test images every 100 steps.
                test_xs, test_ys = mnist.test.next_batch(batch_num)
                test_out = sess.run(net.output, feed_dict={net.x: test_xs})
                print(i, loss, _accuracy(test_ys, test_out))