利用softmax函式對mnist資料集簡單分類
阿新 • • 發佈:2018-12-19
mnist資料集的特點
- 每一張圖片包含28**28個畫素,我們把這一個陣列展開成一個向量,長度是28*28=784。因此在 MNIST訓練資料集中mnist.train.images 是一個形狀為 [60000, 784] 的張量,第一個維度數字用 來索引圖片,第二個維度數字用來索引每張圖片中的畫素點。圖片裡的某個畫素的強度值介於0-1 之間。
- MNIST資料集的標籤是介於0-9的數字,我們要把標籤轉化為“one-hot vectors”。一個onehot向量除了某一位數字是1以外,其餘維度數字都是0,比如標籤0將表示為([1,0,0,0,0,0,0,0,0,0]) ,標籤3將表示為([0,0,0,1,0,0,0,0,0,0]) 。
- 因此, mnist.train.labels 是一個 [60000, 10] 的數字矩陣。
例如,下面這幅圖,代表的數字為5042
softmax函式:
- 我們知道MNIST的結果是0-9,我們的模型可能推測出一張圖片是數字9的概率是80%,是數字8 的概率是10%,然後其他數字的概率更小,總體概率加起來等於1。這是一個使用softmax迴歸模型的經典案例。softmax模型可以用來給不同的物件分配概率。
程式如下:
import tensorflow as tf from tensorflow.examples.tutorials.mnist import input_data # 載入資料集 mnist = input_data.read_data_sets("MNIST_data", one_hot=True) # 定義批次batch_size,一次性放入100張圖片 batch_size = 100 # 計算一個有多少個批次 n_batch = mnist.train.num_examples // batch_size # 定義兩個placeholder x = tf.placeholder(tf.float32, [None, 784]) y = tf.placeholder(tf.float32, [None, 10]) # 建立一個簡單的神經網路 W = tf.Variable(tf.zeros([784, 10])) b = tf.Variable(tf.zeros(10)) prediction = tf.nn.softmax(tf.matmul(x, W) + b) # 二次代價函式 loss = tf.reduce_mean(tf.square(y - prediction)) # 使用梯度下降法 train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss) # 初始化變數 init = tf.initialize_all_variables() # 預測的結果 # tf.argmax()返回最大值所在的列 # 結果存放在一個bool型列表中 correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(prediction, 1)) # 求準確率 accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32)) with tf.Session() as sess: sess.run(init) for epoch in range(21): for batch in range(n_batch): batch_xs, batch_ys = mnist.train.next_batch(batch_size) sess.run(train_step, feed_dict={x: batch_xs, y: batch_ys}) acc = sess.run(accuracy, feed_dict={x:mnist.test.images, y: mnist.test.labels}) print("Iter" + str(epoch) + ",Testing Accuracy" + str(acc))
執行結果如下:
Iter0,Testing Accuracy0.7488 Iter1,Testing Accuracy0.8331 Iter2,Testing Accuracy0.8592 Iter3,Testing Accuracy0.8707 Iter4,Testing Accuracy0.8779 Iter5,Testing Accuracy0.8814 Iter6,Testing Accuracy0.885 Iter7,Testing Accuracy0.8884 Iter8,Testing Accuracy0.8917 Iter9,Testing Accuracy0.8936 Iter10,Testing Accuracy0.8962 Iter11,Testing Accuracy0.8968 Iter12,Testing Accuracy0.8982 Iter13,Testing Accuracy0.8994 Iter14,Testing Accuracy0.9009 Iter15,Testing Accuracy0.9023 Iter16,Testing Accuracy0.9031 Iter17,Testing Accuracy0.9037 Iter18,Testing Accuracy0.9044 Iter19,Testing Accuracy0.9053 Iter20,Testing Accuracy0.9053
準確率大概在90%左右。