1. 程式人生 > >Tensorflow學習之路(一):從MNIST資料集開始

Tensorflow學習之路(一):從MNIST資料集開始

MNIST資料集簡單介紹: MNIST 資料集可在 http://yann.lecun.com/exdb/mnist/ 獲取, 它包含了四個部分: Training set images: train-images-idx3-ubyte.gz (9.9 MB, 解壓後 47 MB, 包含 60,000 個樣本) Training set labels: train-labels-idx1-ubyte.gz (29 KB, 解壓後 60 KB, 包含 60,000 個標籤) Test set images: t10k-images-idx3-ubyte.gz (1.6 MB, 解壓後 7.8 MB, 包含 10,000 個樣本) Test set labels: t10k-labels-idx1-ubyte.gz (5KB, 解壓後 10 KB, 包含 10,000 個標籤) MNIST 資料集來自美國國家標準與技術研究所, National Institute of Standards and Technology (NIST). 訓練集 (training set) 由來自 250 個不同人手寫的數字構成, 其中 50% 是高中學生, 50% 來自人口普查局 (the Census Bureau) 的工作人員. 測試集(test set) 也是同樣比例的手寫數字資料. 下載之後結果: 資料集下載結果

然後本文以Jupyter Notebook為工具對MNIST資料集進行分類 程式碼如下:

import tensorflow as tf
import tensorflow.examples.tutorials.mnist.input_data as input_data
#載入資料集
mnist = input_data.read_data_sets("E:/tensorflow視訊學習/資料/第三週/MNIST_data",one_hot=True)#下載好匯入
#mnist = input_data.read_data_sets("MNIST_data",one_hot=True)#執行程式時從網路下載網路較好可以差的話會有問題
#每個批次的大小
batch_size = 100
#計算一共有多少個批次
n_batch = mnist.train.num_examples // batch_size

#定義兩個placeholder
x = tf.placeholder(tf.float32,[None,784])
y = tf.placeholder(tf.float32,[None,10])

#建立一個簡單的神經網路
W = tf.Variable(tf.zeros([784,10]))
b = tf.Variable(tf.zeros([10]))
prediction = tf.nn.softmax(tf.matmul(x,W)+b)

#二次代價函式
loss = tf.reduce_mean(tf.square(y-prediction))
#使用梯度下降法
train_step = tf.train.GradientDescentOptimizer(0.2).minimize(loss)

#初始化變數
init = tf.global_variables_initializer()

#結果存放在一個布林型列表中
correct_prediction = tf.equal(tf.argmax(y,1),tf.argmax(prediction,1))#argmax返回一維張量中最大的值所在的位置
#求準確率
accuracy = tf.reduce_mean(tf.cast(correct_prediction,tf.float32))

with tf.Session() as sess:
    sess.run(init)
    for epoch in range(21):
        for batch in range(n_batch):
            batch_xs,batch_ys =  mnist.train.next_batch(batch_size)
            sess.run(train_step,feed_dict={x:batch_xs,y:batch_ys})
        
        acc = sess.run(accuracy,feed_dict={x:mnist.test.images,y:mnist.test.labels})
        print("Iter " + str(epoch) + ",Testing Accuracy " + str(acc))

然後把程式碼儲存為.ipynb形式的了就直接執行結果如下: 執行結果