1. 程式人生 > 其它 >TF.Learn 手寫文字識別

TF.Learn 手寫文字識別

minist問題

  • 計算機視覺領域的Hello world
  • 給定55000個圖片,處理成28*28的二維矩陣,矩陣中每個值表示一個畫素點的灰度,作為feature
  • 給定每張圖片對應的字元,作為label,總共有10個label,是一個多分類問題

Tensor Flow

  • 可以按教程用Docker安裝,也可以直接在Linux上安裝
  • 你可能會擔心,不用Docker的話怎麼開那個notebook呢?其實notebook就在主講人的Github頁(https://github.com/random-forests/tutorials)上
  • 可以用這個Chrome外掛:npviewer(https://chrome.google.com/webstore/detail/open-in-nbviewer/ihlhlehlibooakiicbiakgojckpnlali?hl=zh-CN)直接在瀏覽器中閱讀ipynb格式的檔案,而不用在本地啟動iPython notebook
  • 我們的教程在這裡:ep7.ipynb(https://github.com/random-forests/tutorials/blob/master/ep7.ipynb)
  • 把程式碼從ipython notebook中整理出來:tflearn_mnist.py(https://github.com/ahangchen/GoogleML/blob/master/src/tflearn_mnist.py)

程式碼分析

下載資料集

mnist = learn.datasets.load_dataset('mnist')

恩,就是這麼簡單,一行程式碼下載解壓mnist資料,每個img已經灰度化成長784的陣列,每個label已經one-hot成長度10的陣列

numpy讀取影象到記憶體,用於後續操作,包括訓練集(只取前10000個)和驗證集

data = mnist.train.images labels = np.asarray(mnist.train.labels, dtype=np.int32) test_data = mnist. test.images test_labels = np.asarray(mnist.test.labels, dtype=np.int32) max_examples = 10000 data = data[:max_examples] labels = labels[:max_examples]

視覺化影象

def display(i):

img = test_data[i] plt.title('Example %d. Label: %d' % (i, test_labels[i])) plt.imshow(img.reshape((28, 28)), cmap=plt.cm.gray_r) plt.show()

用matplotlib展示灰度圖

訓練分類器

提取特徵(這裡每個圖的特徵就是784個畫素值)

feature_columns = learn.infer_real_valued_columns_from_input(data)

建立線性分類器並訓練

classifier = learn.LinearClassifier(feature_columns=feature_columns, n_classes=10) classifier.fit(data, labels, batch_size=100, steps=1000)

注意要制定n_classes為labels的數量

  • 分類器實際上是在根據每個feature判斷每個label的可能性,
  • 不同的feature有的重要,有的不重要,所以需要設定不同的權重
  • 一開始權重都是隨機的,在fit的過程中,實際上就是在調整權重
  • 最後可能性最高的label就會作為預測輸出
  • 傳入測試集,預測,評估分類效果
result = classifier.evaluate(test_data, test_labels)print result["accuracy"]

速度非常快,而且準確率達到91.4%

可以只預測某張圖,並檢視預測是否跟實際圖形一致

# here's one it gets right print ("Predicted %d, Label: %d" % (classifier.predict(test_data[0]), test_labels[0])) display(0) # and one it gets wrong print ("Predicted %d, Label: %d" % (classifier.predict(test_data[8]), test_labels[8])) display(8)

視覺化權重以瞭解分類器的工作原理

weights = classifier.weights_
a.imshow(weights.T[i].reshape(28, 28), cmap=plt.cm.seismic)

weight視覺化

從上圖可知:

1、這裡展示了8個張圖中,每個畫素點(也就是feature)的weights,

2、紅色表示正的權重,藍色表示負的權重

3、作用越大的畫素,它的顏色越深,也就是權重越大

4、所以權重中紅色部分幾乎展示了正確的數字

Next steps

  • TensorFlow Docker images(https://hub.docker.com/r/tensorflow/tensorflow/)
  • TF.Learn Quickstart(https://www.tensorflow.org/versions/r0.9/tutorials/tflearn/index.html)
  • MNIST tutorial(https://www.tensorflow.org/tutorials/mnist/beginners/index.html)
  • Visualizating MNIST(http://colah.github.io/posts/2014-10-Visualizing-MNIST/)
  • Additional notebooks(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/tools/docker/notebooks)
  • More about linear classifiers(https://www.tensorflow.org/versions/r0.10/tutorials/linear/overview.html#large-scale-linear-models-with-tensorflow)
  • Much more about linear classifiers(http://cs231n.github.io/linear-classify/)
  • Additional TF.Learn samples(https://github.com/tensorflow/tensorflow/tree/master/tensorflow/examples/skflow)

Github工程地址 https://github.com/ahangchen/GoogleML