程式人生 > TensorFlow實戰之LSTM

TensorFlow實戰之LSTM

import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data',one_hot=True)
import numpy as np

# Default number of training iterations and mini-batch size; setnumber()
# may overwrite both before the model is built.
train_num = 1000
batch_num = 150


def setnumber(train=None, batch=None):
    """Set the module-level ``train_num`` and ``batch_num`` settings.

    Any value not passed as an argument is read interactively from stdin
    (training iterations first, then batch size), so calling this with no
    arguments prompts for both, as the script's entry point does.

    Args:
        train: number of training iterations, or None to prompt for it.
        batch: mini-batch size, or None to prompt for it.

    Raises:
        ValueError: if a value is not an integer or is smaller than 1.
    """
    # Without this declaration the assignments below would only bind
    # function locals, and the user's input would be silently discarded.
    global train_num, batch_num
    if train is None:
        train = input('請輸入訓練次數:')
    if batch is None:
        batch = input('請輸入批次:')
    new_train = int(train)
    new_batch = int(batch)
    # Validate both before touching the globals so a bad value leaves the
    # previous settings intact.
    if new_train < 1 or new_batch < 1:
        raise ValueError('train_num and batch_num must both be >= 1')
    train_num = new_train
    batch_num = new_batch

class Net:
    """Single-layer LSTM classifier for flattened 784-pixel images (TF 1.x graph mode).

    Each input row of 784 values is split into 28 steps of 28 values
    (image rows, assuming row-major flattening). Every step is projected by a
    dense ReLU layer. The projected steps go through an LSTM, and the LSTM
    output at the last step is mapped to 10 softmax probabilities.
    """

    def __init__(self, hidden_size=None):
        """Create the placeholders and the dense-layer variables.

        Args:
            hidden_size: width of the per-step projection and of the LSTM
                state. Defaults to ``batch_num + 28``, using the module-level
                ``batch_num`` at construction time. This matches the original
                script. Pass an explicit value to decouple model width from
                the training batch size.
        """
        if hidden_size is None:
            # NOTE(review): tying the model width to the batch size looks
            # accidental (an old comment suggested batch 100 -> 128 units);
            # kept as the default so existing runs are unchanged.
            hidden_size = batch_num + 28
        self.hidden_size = hidden_size
        # Flattened inputs, one example per row: [batch, 784].
        self.x = tf.placeholder(dtype=tf.float32, shape=[None, 784])
        # Target vectors compared against the softmax output: [batch, 10].
        self.y = tf.placeholder(dtype=tf.float32, shape=[None, 10])
        # Per-step projection 28 -> hidden_size, applied before the LSTM.
        self.in_w = tf.Variable(tf.truncated_normal([28, hidden_size], dtype=tf.float32, stddev=0.1))
        self.in_b = tf.Variable(tf.zeros([hidden_size]))
        # Read-out from the last LSTM output to 10 classes.
        self.out_w = tf.Variable(tf.truncated_normal([hidden_size, 10], dtype=tf.float32, stddev=0.1))
        self.out_b = tf.Variable(tf.zeros([10]))

    def forward(self):
        """Build the inference graph and set ``self.output`` ([batch, 10] probabilities)."""
        hidden = self.hidden_size
        # [batch, 784] -> [batch*28, 28]: one 28-value step per row.
        self.y1 = tf.reshape(self.x, [-1, 28])
        # [batch*28, hidden]
        self.y2 = tf.nn.relu(tf.matmul(self.y1, self.in_w) + self.in_b)
        # [batch, 28 steps, hidden]
        self.y3 = tf.reshape(self.y2, [-1, 28, hidden])
        lstm_cell = tf.contrib.rnn.BasicLSTMCell(hidden)
        # Size the initial state from the batch actually fed, not from the
        # global batch_num. This lets the graph accept batches of any size.
        batch_size = tf.shape(self.x)[0]
        init_state = lstm_cell.zero_state(batch_size, dtype=tf.float32)
        outputs, _ = tf.nn.dynamic_rnn(lstm_cell, self.y3, initial_state=init_state, time_major=False)
        # LSTM output at the last step: [batch, hidden].
        self.y4 = outputs[:, -1, :]
        self.output = tf.nn.softmax(tf.matmul(self.y4, self.out_w) + self.out_b)

    def backward(self):
        """Build the loss and an Adam training op (call after ``forward``)."""
        # Mean squared error between softmax probabilities and targets.
        # NOTE(review): softmax cross-entropy on the logits is the usual
        # choice for classification; kept as-is to preserve training behavior.
        self.loss = tf.reduce_mean((self.output - self.y) ** 2)
        self.opt = tf.train.AdamOptimizer().minimize(self.loss)

if __name__ == '__main__':
    # Ask for the run settings, build the graph, then train on MNIST batches.
    setnumber()
    net = Net()
    net.forward()
    net.backward()
    init = tf.global_variables_initializer()
    with tf.Session() as sess:
        sess.run(init)
        for step in range(train_num):
            xs, ys = mnist.train.next_batch(batch_num)
            loss, _ = sess.run([net.loss, net.opt], feed_dict={net.x: xs, net.y: ys})
            # Every 100 steps, report the latest training loss and the
            # accuracy on one batch of test images. The test batch uses
            # batch_num images so it matches the size the graph was built for.
            if step % 100 == 0:
                test_xs, test_ys = mnist.test.next_batch(batch_num)
                test_out = sess.run(net.output, feed_dict={net.x: test_xs})
                accuracy = np.mean(np.argmax(test_ys, axis=1) == np.argmax(test_out, axis=1))
                print('step {}: train loss {:.4f}, test accuracy {:.4f}'.format(step, loss, accuracy))