MNIST Handwritten Digit Recognition -- tensorflow
阿新 • Published: 2018-12-03
This post analyzes the several versions of handwritten digit recognition code that ship with tensorflow. The tensorflow MNIST code is at https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/tutorials/mnist
1: softmax version
    # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
    #
    # Licensed under the Apache License, Version 2.0 (the "License");
    # you may not use this file except in compliance with the License.
    # You may obtain a copy of the License at
    #
    #     http://www.apache.org/licenses/LICENSE-2.0
    #
    # Unless required by applicable law or agreed to in writing, software
    # distributed under the License is distributed on an "AS IS" BASIS,
    # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    # See the License for the specific language governing permissions and
    # limitations under the License.
    # ==============================================================================

    """A very simple MNIST classifier.

    See extensive documentation at
    https://www.tensorflow.org/get_started/mnist/beginners
    """
    from __future__ import absolute_import
    from __future__ import division
    from __future__ import print_function

    import argparse
    import sys

    from tensorflow.examples.tutorials.mnist import input_data

    import tensorflow as tf

    FLAGS = None


    def main(_):
      # Import data
      mnist = input_data.read_data_sets(FLAGS.data_dir, one_hot=True)

      # Create the model
      x = tf.placeholder(tf.float32, [None, 784])
      W = tf.Variable(tf.zeros([784, 10]))
      b = tf.Variable(tf.zeros([10]))
      y = tf.matmul(x, W) + b

      # Define loss and optimizer
      y_ = tf.placeholder(tf.float32, [None, 10])

      # The raw formulation of cross-entropy,
      #
      #   tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(tf.nn.softmax(y)),
      #                                 reduction_indices=[1]))
      #
      # can be numerically unstable.
      #
      # So here we use tf.nn.softmax_cross_entropy_with_logits on the raw
      # outputs of 'y', and then average across the batch.
      cross_entropy = tf.reduce_mean(
          tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
      train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

      sess = tf.InteractiveSession()
      tf.global_variables_initializer().run()
      # Train
      for _ in range(1000):
        batch_xs, batch_ys = mnist.train.next_batch(100)
        sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

      # Test trained model
      correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
      accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
      print(sess.run(accuracy, feed_dict={x: mnist.test.images,
                                          y_: mnist.test.labels}))

    if __name__ == '__main__':
      parser = argparse.ArgumentParser()
      parser.add_argument('--data_dir', type=str,
                          default='/tmp/tensorflow/mnist/input_data',
                          help='Directory for storing input data')
      FLAGS, unparsed = parser.parse_known_args()
      tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)

import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
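These two lines import input_data and use it to load the dataset. input_data itself is only a thin convenience module: it imports read_data_sets from mnist.py, and mnist.py contains all of the functions for fetching and decoding the handwritten digit data. mnist.py first checks whether the dataset files already exist in the target directory and downloads them if they do not. In the full listing above the default directory is /tmp/tensorflow/mnist/input_data (the short snippet passes "MNIST_data/" instead). On a machine that is not Linux-like, such a path is resolved against the current drive: if the code lives on drive E:, the files end up in e:/tmp/tensorflow/mnist/input_data.

Once the files are present they are decompressed and converted to images of shape [index, y, x, depth] and labels of shape [index], and the labels are then one-hot encoded. Each file has a fixed layout. An image file stores, in order, the magic number 2051, num_image, rows, cols, and then the pixel data, where num_image is the number of images (not the image size) and rows and cols are the height and width of each image. The raw files hold 60000 training images and 10000 test images. By default read_data_sets holds out 5000 of the training images for validation, so mnist.train.images has shape (55000, 784) and mnist.test.images has shape (10000, 784), with pixel values scaled to the range 0 to 1. Each of these is simply a two-dimensional array with one row per image and 784 = 28 x 28 columns. Every split contains images and labels, which are accessed as mnist.train.images, mnist.train.labels, and so on. Because the labels are one-hot encoded and there are 10 digit classes, labels is an array with 10 columns.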
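The file layout and preprocessing described above can be sketched in plain Python. This is only an illustration under stated assumptions, not the code in mnist.py: the function names parse_idx_images and one_hot and the tiny in-memory "file" are hypothetical, and the header is assumed to be four big-endian 32-bit integers followed by one unsigned byte per pixel.

```python
import struct


def parse_idx_images(buf):
    """Parse an image buffer into flat rows of floats scaled to [0, 1]."""
    # Assumed header: magic number, image count, rows, cols (big-endian uint32).
    magic, num_images, rows, cols = struct.unpack(">IIII", buf[:16])
    if magic != 2051:
        raise ValueError("unexpected magic number: %d" % magic)
    pixels_per_image = rows * cols
    images = []
    for i in range(num_images):
        start = 16 + i * pixels_per_image
        raw = buf[start:start + pixels_per_image]
        # Flatten each image into one row and scale bytes from 0..255 to 0..1.
        images.append([p / 255.0 for p in raw])
    return images


def one_hot(label, num_classes=10):
    """Turn a class index into a vector with a single 1.0 entry."""
    vec = [0.0] * num_classes
    vec[label] = 1.0
    return vec


# Hypothetical toy buffer: 2 images of 2x2 pixels instead of 28x28.
toy = struct.pack(">IIII", 2051, 2, 2, 2) + bytes([0, 255, 51, 102, 255, 255, 0, 0])
images = parse_idx_images(toy)
labels = [one_hot(3), one_hot(7)]
```

With real MNIST files the same steps would give one row of 784 values per image and one row of 10 values per label.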
    # Create the model
    x = tf.placeholder(tf.float32, [None, 784])
    W = tf.Variable(tf.zeros([784, 10]))
    b = tf.Variable(tf.zeros([10]))
    y = tf.matmul(x, W) + b

    cross_entropy = tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(labels=y_, logits=y))
    train_step = tf.train.GradientDescentOptimizer(0.5).minimize(cross_entropy)

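The comment in the full listing notes that the raw cross-entropy formula can be numerically unstable. A minimal pure-Python sketch of why, and of the usual log-sum-exp remedy, is given here. It illustrates the idea only and is not TensorFlow's implementation; all function names are hypothetical.

```python
import math


def naive_cross_entropy(logits, one_hot_label):
    """Raw formula: -sum(y_ * log(softmax(y))). Breaks for extreme logits."""
    exps = [math.exp(z) for z in logits]  # math.exp overflows for large z
    total = sum(exps)
    return -sum(t * math.log(e / total) for t, e in zip(one_hot_label, exps))


def stable_cross_entropy(logits, one_hot_label):
    """Same quantity computed from the logits via log-sum-exp."""
    m = max(logits)
    # Subtracting the max keeps every exponent <= 0, so exp cannot overflow.
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return -sum(t * (z - log_sum_exp) for t, z in zip(one_hot_label, logits))


def mean_loss(batch_logits, batch_labels):
    """Average the per-example losses over the batch (the reduce_mean step)."""
    losses = [stable_cross_entropy(l, t)
              for l, t in zip(batch_logits, batch_labels)]
    return sum(losses) / len(losses)


label = [1.0, 0.0, 0.0]

# For moderate logits both versions agree.
small = [2.0, 1.0, 0.1]
agree = abs(naive_cross_entropy(small, label)
            - stable_cross_entropy(small, label)) < 1e-12

# For extreme logits the naive version raises, the stable one does not.
extreme = [1000.0, 0.0, -1000.0]
try:
    naive_cross_entropy(extreme, label)
    naive_failed = False
except (OverflowError, ValueError):
    naive_failed = True
stable_value = stable_cross_entropy(extreme, label)
```

In this sketch the naive version fails on the extreme input because math.exp(1000.0) overflows, while the stable version returns a loss of essentially zero for the confidently correct prediction.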
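    sess = tf.InteractiveSession()
    tf.global_variables_initializer().run()

Variables that have been defined must be explicitly initialized before they are used, which is what tf.global_variables_initializer().run() does here.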
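    for _ in range(1000):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys})

The main optimization happens here. The loop is the gradient descent process, and each iteration is one update step. Each step uses 100 samples, which makes this mini-batch gradient descent: full-batch gradient descent uses every sample for every update, which is computationally expensive, so each iteration works on a small sample instead. Note that y_ is simply the labels. The code that follows evaluates the model and is not discussed here. One more point: x is defined with shape [None, 784] because x has to be fed in on every iteration, and with None as the first dimension a batch of any size can be passed. The same placeholder takes 100 rows during training and the whole test set during evaluation.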
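To make the mini-batch update concrete, a self-contained sketch of softmax regression trained with mini-batch gradient descent is given here. It uses plain Python on hypothetical toy data (3 features, 3 classes) rather than TensorFlow and MNIST. All names are illustrative, and drawing each batch with rng.sample is a simplification and is not claimed to match how next_batch selects rows.

```python
import math
import random


def softmax(logits):
    """Softmax with the max subtracted first for numerical stability."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def logits_for(x, W, b):
    """Compute y = xW + b for a single example x."""
    return [sum(x[i] * W[i][k] for i in range(len(x))) + b[k]
            for k in range(len(b))]


def train_step(batch_x, batch_y, W, b, lr):
    """One gradient descent update using only the rows in this mini-batch."""
    n = len(batch_x)
    grad_W = [[0.0] * len(b) for _ in W]
    grad_b = [0.0] * len(b)
    for x, y_true in zip(batch_x, batch_y):
        probs = softmax(logits_for(x, W, b))
        for k in range(len(b)):
            # Gradient of the mean cross-entropy w.r.t. logit k: (p_k - y_k) / n
            delta = (probs[k] - y_true[k]) / n
            grad_b[k] += delta
            for i in range(len(x)):
                grad_W[i][k] += x[i] * delta
    for i in range(len(W)):
        for k in range(len(b)):
            W[i][k] -= lr * grad_W[i][k]
    for k in range(len(b)):
        b[k] -= lr * grad_b[k]


def accuracy(xs, ys, W, b):
    """Fraction of rows whose argmax prediction matches the one-hot label."""
    hits = 0
    for x, y_true in zip(xs, ys):
        y = logits_for(x, W, b)
        hits += y.index(max(y)) == y_true.index(max(y_true))
    return hits / len(xs)


# Hypothetical toy data: the feature at the label's index is the largest.
rng = random.Random(0)
xs, ys = [], []
for n in range(60):
    label = n % 3
    x = [rng.uniform(0.0, 0.2) for _ in range(3)]
    x[label] += 1.0
    xs.append(x)
    ys.append([1.0 if k == label else 0.0 for k in range(3)])

W = [[0.0] * 3 for _ in range(3)]  # zero-initialized, as in the listing
b = [0.0] * 3
for _ in range(200):
    idx = rng.sample(range(len(xs)), 10)  # mini-batch of 10 rows
    train_step([xs[i] for i in idx], [ys[i] for i in idx], W, b, lr=0.5)

train_accuracy = accuracy(xs, ys, W, b)
```

Training feeds 10 rows per step while accuracy runs over all 60 rows with the same logits_for function. That is the role the None dimension plays for the placeholder x: the model does not depend on how many rows are passed in.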
Original post: http://www.hahaszj.top/uncategorized/mnist%e6%89%8b%e5%86%99%e4%bd%93%e8%af%86%e5%88%ab-tensorflow/186