rnn-tf代碼詳解
阿新 • • 發佈:2019-03-08
二維 enc cte read 每一個 序列 tor ssi basic
手寫數字識別經典案例,旨在熟悉RNN結構,掌握tf編寫RNN的方法。
# coding:utf-8 import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data mnist=input_data.read_data_sets("./data",one_hot=True) # 常規參數 train_rate=0.001 train_step=10000 batch_size=1280 display_step=100 # rnn參數 frame_size=28 # 輸入特征數 sequence_length=28 # 輸入個數 hidden_num=100 # 隱層神經元個數 n_classes=10 # 定義輸入,輸出 # 此處輸入格式是樣本數*特征數,特征是把圖片拉成一維的,當然一維還是二維自己定,改成相應的代碼就行了 x=tf.placeholder(dtype=tf.float32,shape=[None,sequence_length*frame_size],name="inputx") y=tf.placeholder(dtype=tf.float32,shape=[None,n_classes],name="expected_y") # 定義權值 # 註意權值設定只設定v, u和w無需設定 weights=tf.Variable(tf.truncated_normal(shape=[hidden_num,n_classes])) # 全連接層權重 bias=tf.Variable(tf.zeros(shape=[n_classes])) def RNN(x,weights,bias): x=tf.reshape(x,shape=[-1,sequence_length,frame_size]) # 3維 rnn_cell=tf.nn.rnn_cell.BasicRNNCell(hidden_num) init_state=tf.zeros(shape=[batch_size,rnn_cell.state_size]) # 其實這是一個深度RNN網絡,對於每一個長度為n的序列[x1,x2,x3,...,xn]的每一個xi,都會在深度方向跑一遍RNN,跑上hidden_num個隱層單元 output,states=tf.nn.dynamic_rnn(rnn_cell,x,dtype=tf.float32) return tf.nn.softmax(tf.matmul(output[:,-1,:],weights)+bias,1) # y=softmax(vh+c) predy=RNN(x,weights,bias) # 以下所有神經網絡大同小異 cost=tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(logits=predy,labels=y)) train=tf.train.AdamOptimizer(train_rate).minimize(cost) correct_pred=tf.equal(tf.argmax(predy,1),tf.argmax(y,1)) accuracy=tf.reduce_mean(tf.to_float(correct_pred)) sess=tf.Session() sess.run(tf.global_variables_initializer()) step=1 testx,testy=mnist.test.next_batch(batch_size) while step<train_step: batch_x,batch_y=mnist.train.next_batch(batch_size) _loss,__=sess.run([cost,train],feed_dict={x:batch_x,y:batch_y}) if step % display_step ==0: acc,loss=sess.run([accuracy,cost],feed_dict={x:testx,y:testy}) print(step,acc,loss) step+=1
這是最簡單的RNN,後面還有非常非常非常復雜的在等你。
rnn-tf代碼詳解