1. 程式人生 > >Tensorflow(1):MNIST識別自己手寫的數字--入門篇(Softmax迴歸)

Tensorflow(1):MNIST識別自己手寫的數字--入門篇(Softmax迴歸)

  機器學習入門都是從MNIST開始,Tensorflow官方社群提供了十分詳細的教程【MNIST機器學習入門】。但是我們顯然不滿足於僅僅把官方的程式碼複製一遍然後輸出個結果,我們想能不能實現自己手寫數字的識別。
  本文作為Tensorflow入門,結合官方程式碼,利用Softmax迴歸函式,實現模型的訓練、儲存、以及重新載入,完成對自己手寫數字的識別。

1.模型訓練及儲存

  模型我們採用Softmax迴歸函式,具體程式碼參考【MNIST機器學習入門】,這裡用梯度下降演算法以0.01學習率最小化交叉熵對模型進行1000次訓練。

import tensorflow as tf
from
tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('MNIST_data', one_hot=True) # 插入資料 # name在儲存模型時非常有用 x = tf.placeholder("float", [None, 784], name='x') W = tf.Variable(tf.zeros([784, 10]), name='W') b = tf.Variable(tf.zeros([10]), name='b') y = tf.nn.softmax(tf.matmul(x, W) + b, name='y'
) # y預測概率分佈 y_ = tf.placeholder("float", [None, 10]) # y_實際概率分佈 cross_entropy = -tf.reduce_sum(y_ * tf.log(y)) # 交叉熵 # 梯度下降演算法以0.01學習率最小化交叉熵 train_step = tf.train.GradientDescentOptimizer(0.01).minimize(cross_entropy) init = tf.initialize_all_variables() # 初始化變數 sess = tf.Session() sess.run(init) saver = tf.train.Saver() for
i in range(1000): # 開始訓練模型,迴圈1000次 batch_xs, batch_ys = mnist.train.next_batch(100) sess.run(train_step, feed_dict={x: batch_xs, y_: batch_ys}) saver.save(sess, 'minst_model.ckpt') # 儲存模型

  在程式碼最前面,定義張量(變數)時,我們給每個張量(變數)都加了name關鍵字,這個對我們後期再次載入模型很重要。
  在最後,我們利用saver.save()函式,儲存模型。模型名稱為minst_model.ckpt。之後我們可以在資料夾下看到4個檔案:

  • checkpoint: 儲存目錄下所有模型檔案列表
  • minst_model.ckpt.meta :儲存了計算圖的結構,可以理解為模型的結構
  • minst_model.ckpt.index 和 minst_model.ckpt.data-00000-of-00001:儲存了模型中所有變數的值.

2.模型載入

   儲存好模型之後,我們利用自己的圖片對模型進行測試。我們利用windows自帶的畫圖軟體,進行數字手寫,並儲存成28*28畫素的png圖片。例如0,1,2手寫體圖片,如下圖所示:


     

  整個測試程式碼如下:

from PIL import Image, ImageFilter
import tensorflow as tf

def imageprepare():
    file_name = 'pic/2-3.png'  # 圖片路徑
    myimage = Image.open(file_name).convert('L')  # 轉換成灰度圖
    tv = list(myimage.getdata())  # 獲取畫素值
    # 轉換畫素範圍到[0 1], 0是純白 1是純黑
    tva = [(255-x)*1.0/255.0 for x in tv] 
    return tva

result = imageprepare()
init = tf.global_variables_initializer()
saver = tf.train.Saver

with tf.Session() as sess:
    sess.run(init)
    saver = tf.train.import_meta_graph('minst_model.ckpt.meta')  # 載入模型結構
    saver.restore(sess,  'minst_model.ckpt')  # 載入模型引數

    graph = tf.get_default_graph()  # 計算圖
    x = graph.get_tensor_by_name("x:0")  # 從模型中獲取張量x
    y = graph.get_tensor_by_name("y:0")  # 從模型中獲取張量y

    prediction = tf.argmax(y, 1)
    predint = prediction.eval(feed_dict={x: [result]}, session=sess)
    print(predint[0])

  在載入模型時,我們先用tf.train.import_meta_graph()載入模型的結構,之後利用saver.restore()載入模型的訓練好的引數。graph.get_tensor_by_name()依照名字(name)從模型中獲取張量。所以前面在儲存模型時我們給每個張量和變數都加了name關鍵字
  關於如何儲存和載入訓練模型可以參見部落格【TensorFlow儲存還原模型的正確方式】

3.識別結果

  輸出的識別結果如下所示:


    

     

    

  經測試,該方法基本識別率可以達到90%左右。所以基本可以滿足要求。

4.注意事項

  最早時,我手寫數字進行識別時,發現準確率很低。
  後來發現原因是:(1)我自己手動畫的數字線條太細了;(2)畫的有些數字在圖片中的位置沒有位於中心;(3)訓練集是西方的手寫數字,和中國的手寫數字習慣不同。下面是官方的訓練資料中的部分數字。


  在畫圖時,數字效果(畫筆粗細等)儘量和上面訓練集保持一致,就會得到較高的識別率!
  是以為記!