1. 程式人生 > TensorFlow 實戰Google深度學習框架 第五章 5.2.1 MNIST數字識別 原始碼

TensorFlow 實戰Google深度學習框架 第五章 5.2.1 MNIST數字識別 原始碼

  1 import os
  2 import tab
  3 import tensorflow as tf
  4 
  5 print "tensorflow 5.2 "
  6 
  7 from tensorflow.examples.tutorials.mnist import input_data
  8 
  9 '''
 10 mnist = input_data.read_data_sets("/asky/tensorflow/mnist_data",one_hot=True)
 11 print "-------------------------------------"
 12
print "Training data size: ",mnist.train.num_examples 13 print "-------------------------------------" 14 print "Validating data size: ",mnist.validation.num_examples 15 print "-------------------------------------" 16 print "Testing data size: " ,mnist.test.num_examples 17 print "-------------------------------------"
18 print "Example training data: ",mnist.train.images[0] 19 print "-------------------------------------" 20 print "Example training data label: ",mnist.train.labels[0] 21 22 batch_size = 100 23 xs,ys=mnist.train.next_batch(batch_size) 24 25 print "X shape:",xs.shape 26 27 print "Y shape:",ys.shape
28 29 30 print "Test Tezt" 31 ''' 32 33 INPUT_NODE = 784 34 OUTPUT_NODE = 10 35 36 LAYER1_NODE = 500 37 38 BATCH_SIZE = 100 39 40 LEARNING_RATE_BASE = 0.8 41 LEARNING_RATE_DECAY = 0.99 42 43 REGULARIZATION_RATE = 0.0001 44 TRAINING_STEPS = 30000 45 MOVING_AVERAGE_DECAY = 0.99 46 47 def inference(input_tensor,avg_class,weights1,biases1,weights2,biases2): 48 if avg_class == None: 49 layer1 = tf.nn.relu(tf.matmul(input_tensor,weights1)+biases1) 50 return tf.matmul(layer1,weights2)+biases2 51 else: 52 layer1 = tf.nn.relu( 53 tf.matmul(input_tensor,avg_class.average(weights1))+ 54 avg_class.average(biases1)) 55 return tf.matmul(layer1,avg_class.average(weights2))+avg_class.average(biases2) 56 57 def train(mnist): 58 x = tf.placeholder(tf.float32,[None,INPUT_NODE],name='x-input') 59 y_ = tf.placeholder(tf.float32,[None,OUTPUT_NODE],name='y-input') 60 weights1 = tf.Variable( 61 tf.truncated_normal([INPUT_NODE,LAYER1_NODE],stddev=0.1)) 62 biases1 = tf.Variable( tf.constant(0.1,shape=[LAYER1_NODE])) 63 64 weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE,OUTPUT_NODE],stddev=0.1)) 65 biases2 = tf.Variable(tf.constant(0.1,shape=[OUTPUT_NODE])) 66 67 y = inference(x,None,weights1,biases1,weights2,biases2) 68 69 global_step = tf.Variable(0,trainable=False) 70 71 variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,global_step) 72 73 variables_averages_op = variable_averages.apply(tf.trainable_variables()) 74 75 average_y = inference(x,variable_averages,weights1,biases1,weights2,biases2) 76 77 #cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(y, tf.argmax(y_, 1 )) 78 cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=tf.argmax(y_, 1), logits=y) 79 80 cross_entropy_mean = tf.reduce_mean(cross_entropy) 81 82 regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE) 83 84 regularization = regularizer(weights1) + regularizer(weights2) 85 86 loss = cross_entropy_mean + regularization 87 88 learning_rate = 
tf.train.exponential_decay( 89 LEARNING_RATE_BASE, 90 global_step, 91 mnist.train.num_examples/BATCH_SIZE, 92 LEARNING_RATE_DECAY 93 ) 94 95 train_step = tf.train.GradientDescentOptimizer(learning_rate).minimize(loss, global_step=global_step) 96 97 with tf.control_dependencies([train_step,variables_averages_op]): 98 train_op = tf.no_op(name='train') 99 100 correct_prediction = tf.equal(tf.argmax(average_y,1),tf.argmax(y_,1)) 101 accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32)) 102 103 with tf.Session() as sess: 104 tf.global_variables_initializer().run() 105 validate_feed = {x: mnist.validation.images, 106 y_: mnist.validation.labels} 107 test_feed = {x: mnist.test.images, y_: mnist.test.labels } 108 for i in range(TRAINING_STEPS): 109 if i % 1000 ==0: 110 validate_acc = sess.run(accuracy,feed_dict=validate_feed) 111 print ("After %d training step(s),validation accuracy " 112 "using average model is %g " %(i,validate_acc) ) 113 xs, ys = mnist.train.next_batch(BATCH_SIZE) 114 sess.run(train_op,feed_dict={x: xs , y_ : ys}) 115 116 test_acc = sess.run(accuracy,feed_dict=test_feed) 117 print ( "After %d training step(s),test accuracy using average " 118 "model is %g " % (TRAINING_STEPS , test_acc) ) 119 120 def main(argv=None) : 121 mnist = input_data.read_data_sets("/asky/tensorflow/mnist_data",one_hot=True) 122 train(mnist) 123 124 if __name__ == '__main__': 125 tf.app.run()