TensorFlow -- 訓練MNIST資料集
阿新 • • 發佈:2019-02-11
# -*- coding:utf-8 -*-
# Train a single-layer softmax classifier on MNIST with TensorFlow 1.x,
# save the trained variables to a checkpoint, then restore the checkpoint in
# a fresh session, re-evaluate it and display two sample digits.
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import pylab

# Download (if necessary) and load the data set with one-hot labels.
mnist = input_data.read_data_sets('MNIST_data/', one_hot=True)

# Start from an empty default graph.
tf.reset_default_graph()

# Placeholders.
# Each image is flattened to 28 * 28 = 784 values.
x = tf.placeholder(tf.float32, [None, 784])
# One-hot labels over 10 classes.
y = tf.placeholder(tf.float32, [None, 10])

# Trainable parameters.
W = tf.Variable(tf.random_normal(([784, 10])))
b = tf.Variable(tf.zeros([10]))

# Forward pass: keep the raw logits so the loss can be computed from them
# directly; `pred` holds the class probabilities used for inference.
logits = tf.matmul(x, W) + b
pred = tf.nn.softmax(logits)

# Loss: mean cross-entropy over the batch.
# Computing -sum(y * log(softmax(logits))) by hand yields log(0) = -inf (and
# NaN after the multiply) whenever a probability underflows to zero; the
# fused op computes the same quantity in a numerically stable way.
cost = tf.reduce_mean(
    tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=logits))

# Hyper-parameters.
learning_rate = 0.01

# Plain gradient-descent optimizer.
optimizer = tf.train.GradientDescentOptimizer(learning_rate).minimize(cost)

# Evaluation ops, built once and reused by both sessions below instead of
# adding duplicate nodes to the graph each time accuracy is needed.
correct_prediction = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))

# Training schedule.
training_epochs = 25
batch_size = 100
display_step = 1

saver = tf.train.Saver()
model_path = 'log/521model.ckpt'

# Number of mini-batches per epoch; it does not change between epochs, so it
# is computed once (presumably 550 for the standard training split -- verify
# against mnist.train.num_examples).
total_batch = mnist.train.num_examples // batch_size

# Training session.
with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())

    for epoch in range(training_epochs):
        avg_cost = 0.0
        # One pass over the training set in mini-batches.
        for i in range(total_batch):
            batch_xs, batch_ys = mnist.train.next_batch(batch_size)
            # Run one optimization step and fetch the batch loss.
            _, c = sess.run([optimizer, cost],
                            feed_dict={x: batch_xs, y: batch_ys})
            # Accumulate the mean loss over the epoch.
            avg_cost += c / total_batch

        # Progress report.
        if (epoch+1) % display_step == 0:
            print('Epochs:', '%04d' % (epoch+1),
                  'cost:', '{:.9f}'.format(avg_cost))

    print('Finished!')

    # Evaluate on the test set.
    print('Accuracy:',
          accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Save the trained variables.
    save_path = saver.save(sess, model_path)
    print('Model saved in file: %s' % save_path)

# Reload the checkpoint in a fresh session.
print('Starting loading model')
with tf.Session() as sess:
    # No initializer run is needed here: restoring assigns every saved
    # variable, which is itself a way of initializing them.
    saver.restore(sess, model_path)

    # Re-evaluate the restored model on the test set.
    print('Accuracy:',
          accuracy.eval({x: mnist.test.images, y: mnist.test.labels}))

    # Predict two samples and show them next to their labels.
    output = tf.argmax(pred, 1)
    batch_xs, batch_ys = mnist.train.next_batch(2)
    outputval, predv = sess.run([output, pred],
                                feed_dict={x: batch_xs, y: batch_ys})
    print(outputval, predv, batch_ys)

    # Display each fetched sample as a 28-column image.
    for sample in batch_xs:
        im = sample.reshape(-1, 28)
        pylab.imshow(im)
        pylab.show()