TensorFlow MNIST Training and Testing
阿新 • Published: 2018-11-30
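I won't go over TensorFlow's introductory concepts here, though jumping straight into the code may feel a bit abrupt! So first of all, happy Spring Festival, everyone!
Hardware: NVIDIA GTX 1080
Software: Windows 7, Python 3.6.5, tensorflow-gpu 1.4.0
Alright, on to the code! It is broken down step by step and should be easy to follow!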
Step 1: Import TensorFlow and the MNIST input_data helper
import tensorflow.examples.tutorials.mnist.input_data as input_data
import tensorflow as tf
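The input_data helper imported here can return MNIST labels as one-hot vectors when read_data_sets is called with one_hot = True, which is why the label placeholder y_ has 10 columns. Below is a minimal pure-Python sketch of that encoding. The helper name one_hot is hypothetical and not part of TensorFlow.

```python
def one_hot(label, num_classes=10):
    """Encode an integer class label as a one-hot list of length num_classes."""
    vector = [0.0] * num_classes
    vector[label] = 1.0
    return vector

labels = [3, 0, 9]
encoded = [one_hot(label) for label in labels]
print(encoded[0])  # [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]

# Decoding: the position of the 1.0 recovers the digit, which is what
# tf.argmax(y_, 1) does row by row in the accuracy computation
decoded = [row.index(1.0) for row in encoded]
print(decoded)  # [3, 0, 9]
```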
Step 2: Define the weights and biases
#############################define weights and bias########################
# conv1: 5x5 kernels, 1 input channel -> 32 feature maps
w_conv1 = tf.Variable(tf.truncated_normal([5, 5, 1, 32], stddev = 0.1))
b_conv1 = tf.Variable(tf.constant(0.1, shape = [32]))
# conv2: 5x5 kernels, 32 -> 64 feature maps
w_conv2 = tf.Variable(tf.truncated_normal([5, 5, 32, 64], stddev = 0.1))
b_conv2 = tf.Variable(tf.constant(0.1, shape = [64]))
# fc1: flattened 7x7x64 feature map -> 1024 hidden units
w_fc1 = tf.Variable(tf.truncated_normal([7*7*64, 1024], stddev = 0.1))
b_fc1 = tf.Variable(tf.constant(0.1, shape = [1024]))
# fc2: 1024 hidden units -> 10 digit classes
w_fc2 = tf.Variable(tf.truncated_normal([1*1*1024, 10], stddev = 0.1))
b_fc2 = tf.Variable(tf.constant(0.1, shape = [10]))
#############################################################################
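The 7*7*64 input size of w_fc1 follows from the architecture in Step 3: with padding='SAME', a layer with stride s maps a spatial size n to ceil(n / s), so the two stride-2 max-pooling layers shrink 28x28 to 14x14 and then to 7x7, with 64 channels coming out of conv2. The short plain-Python sketch below checks that arithmetic and counts the trainable parameters defined above. The helper name same_output_size is hypothetical.

```python
import math

def same_output_size(size, stride):
    """Spatial size after a conv/pool layer with padding='SAME': ceil(size / stride)."""
    return math.ceil(size / stride)

size = 28
size = same_output_size(size, 1)  # conv1, stride 1 -> 28
size = same_output_size(size, 2)  # pool1, stride 2 -> 14
size = same_output_size(size, 1)  # conv2, stride 1 -> 14
size = same_output_size(size, 2)  # pool2, stride 2 -> 7

flat_features = size * size * 64  # 64 channels after conv2
print(size, flat_features)  # 7 3136

# Trainable parameters per layer: weights + biases
param_counts = {
    "conv1": 5 * 5 * 1 * 32 + 32,
    "conv2": 5 * 5 * 32 * 64 + 64,
    "fc1": flat_features * 1024 + 1024,
    "fc2": 1024 * 10 + 10,
}
print(param_counts)
print(sum(param_counts.values()))  # 3274634
```

Almost all of the roughly 3.27 million parameters sit in fc1, which is typical for this kind of small CNN.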
Step 3: Design the network architecture
##############################define model###################################
#define input size
x_ = tf.placeholder(tf.float32, [None, 28*28])
y_ = tf.placeholder(tf.float32, [None, 10])
#reshape input data
x_input = tf.reshape(x_, [-1, 28, 28, 1])
#conv1
conv1 = tf.nn.conv2d(x_input, w_conv1, strides=[1,1,1,1], padding='SAME') + b_conv1
relu1 = tf.nn.relu(conv1)
#pool1
pool1 = tf.nn.max_pool(relu1, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
#conv2
conv2 = tf.nn.conv2d(pool1, w_conv2, strides=[1,1,1,1], padding='SAME') + b_conv2
relu2 = tf.nn.relu(conv2)
#pool2
pool2 = tf.nn.max_pool(relu2, ksize=[1,2,2,1], strides=[1,2,2,1], padding='SAME')
#reshape pool2 for fc1
pool2_reshape = tf.reshape(pool2, [-1, 7*7*64])
#fc1
fc1 = tf.matmul(pool2_reshape, w_fc1) + b_fc1
relu3 = tf.nn.relu(fc1)
#dropout
keep_prob = tf.placeholder(tf.float32)
fc1_dropout = tf.nn.dropout(relu3, keep_prob)
#fc2
fc2 = tf.matmul(fc1_dropout, w_fc2) + b_fc2
#softmax
y_out = tf.nn.softmax(fc2)
#define loss: compute the cross-entropy from the logits (fc2) instead of tf.log(y_out),
#which gives -inf (and NaN gradients) if a predicted probability underflows to 0
cross_entropy_loss = tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=fc2))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy_loss)
#define accuracy
correct_prediction = tf.equal(tf.argmax(y_out,1), tf.argmax(y_,1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))
###############################################################################
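Computing the cross-entropy as the log of an explicit softmax output (the y_ * tf.log(y_out) form) can break when the probability of the true class underflows to 0; computing it directly from the logits (fc2), as tf.nn.softmax_cross_entropy_with_logits does, avoids that. The plain-Python sketch below illustrates the idea with the log-sum-exp trick; it is not TensorFlow's internal kernel, and both function names are hypothetical.

```python
import math

def naive_cross_entropy(logits, label):
    """Softmax first, then -log(p[label]), mirroring -tf.reduce_sum(y_ * tf.log(y_out))."""
    exps = [math.exp(z) for z in logits]
    prob = exps[label] / sum(exps)
    return -math.log(prob)  # ValueError in Python (-inf in TensorFlow) if prob underflows to 0.0

def stable_cross_entropy(logits, label):
    """Cross-entropy straight from the logits via log-sum-exp: log(sum(exp(z))) - z[label]."""
    m = max(logits)  # subtracting the max keeps every exp() argument <= 0
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[label]

# Ordinary logits: both versions agree
print(naive_cross_entropy([2.0, 1.0, 0.1], 0))
print(stable_cross_entropy([2.0, 1.0, 0.1], 0))

# A very confident wrong prediction: exp(-800) underflows to 0.0
try:
    naive_cross_entropy([0.0, -800.0], 1)
except ValueError:
    print("naive version failed: log(0)")
print(stable_cross_entropy([0.0, -800.0], 1))  # 800.0
```

In TensorFlow the naive form does not raise an exception; the loss silently becomes inf and the gradients NaN, which is harder to spot.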
Step 4: Load the data, then train and test the network. The test accuracy reached 96.5%.
####################################start model################################
import os
#load data
mnist = input_data.read_data_sets("MNIST_data/", one_hot = True)
#define sess
sess = tf.Session()
sess.run(tf.global_variables_initializer())
#define save model (only the 3 most recent checkpoints are kept)
saver = tf.train.Saver(max_to_keep=3)
os.makedirs("MNIST_model", exist_ok=True)  # make sure the checkpoint directory exists
#train
for step in range(1000):
    batch = mnist.train.next_batch(50)
    if step % 100 == 0:
        # evaluate on the current batch with dropout disabled (keep_prob = 1.0)
        train_accuracy = accuracy.eval(session=sess, feed_dict={x_:batch[0], y_:batch[1], keep_prob:1.0})
        print("step %d, train_accuracy %g" %(step, train_accuracy))
        saver.save(sess, "MNIST_model/model.ckpt", global_step=step)  # writes model.ckpt-<step>
    train_step.run(session=sess, feed_dict={x_:batch[0], y_:batch[1], keep_prob:0.5})
#test
print("test accuracy %g" %accuracy.eval(session=sess, feed_dict={x_:mnist.test.images, y_:mnist.test.labels, keep_prob:1.0}))
sess.close()
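During training the code feeds keep_prob:0.5, and for evaluation keep_prob:1.0. In TensorFlow 1.x, tf.nn.dropout keeps each element with probability keep_prob and scales the kept elements by 1 / keep_prob, so the expected activation is unchanged and no extra rescaling is needed at test time. A plain-Python sketch of this "inverted dropout" rule follows; the function name and the fixed random seed are illustrative assumptions.

```python
import random

def inverted_dropout(activations, keep_prob, rng):
    """Zero each unit with probability 1 - keep_prob; scale survivors by 1 / keep_prob."""
    return [a / keep_prob if rng.random() < keep_prob else 0.0 for a in activations]

rng = random.Random(0)  # fixed seed so the run is reproducible
units = [1.0] * 100000

train_out = inverted_dropout(units, 0.5, rng)  # training: keep_prob = 0.5
eval_out = inverted_dropout(units, 1.0, rng)   # evaluation: keep_prob = 1.0

# Survivors are scaled to 2.0, dropped units are 0.0, and the mean stays close to 1.0
print(sorted(set(train_out)))           # [0.0, 2.0]
print(sum(train_out) / len(train_out))  # close to 1.0
# With keep_prob = 1.0, rng.random() < 1.0 always holds, so nothing is dropped
print(eval_out == units)                # True
```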
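Step 5: Closing remarks
The code is partly commented, so you can follow it through the comments; if anything is still unclear, a quick Baidu or Google search should help.
That's a wrap!
For any questions, add my one and only QQ: 2258205918 (name: samylee)!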