tensorflow(3)——Mnist資料集
阿新 • • 發佈:2019-01-12
學習《Tensorflow入門教程》記錄
一、載入資料集
import numpy as np import tensorflow as tf import matplotlib.pyplot as plt from tensorflow.examples.tutorials.mnist import input_data mnist = input_data.read_data_sets('data/', one_hot=True) print print (" 型別是 %s" % (type(mnist))) print (" 訓練資料有 %d" % (mnist.train.num_examples)) print (" 測試資料有 %d" % (mnist.test.num_examples))
執行結果:
型別是 <class 'tensorflow.contrib.learn.python.learn.datasets.base.Datasets'>
訓練資料有 55000
測試資料有 10000
注意:如果Mnist載入失敗,可以自行下載資料集,放在當前路徑的data資料夾下。
二、資料集的規格
trainimg = mnist.train.images trainlabel = mnist.train.labels testimg = mnist.test.images testlabel = mnist.test.labels # 28 * 28 * 1 print (" 資料型別 is %s" % (type(trainimg))) print (" 標籤型別 %s" % (type(trainlabel))) print (" 訓練集的shape %s" % (trainimg.shape,)) print (" 訓練集的標籤的shape %s" % (trainlabel.shape,)) print (" 測試集的shape' is %s" % (testimg.shape,)) print (" 測試集的標籤的shape %s" % (testlabel.shape,))
執行結果:
資料型別 is <class 'numpy.ndarray'>
標籤型別 <class 'numpy.ndarray'>
訓練集的shape (55000, 784)
訓練集的標籤的shape (55000, 10)
測試集的shape' is (10000, 784)
測試集的標籤的shape (10000, 10)
三、資料集的形式
nsample = 2 randidx = np.random.randint(trainimg.shape[0], size=nsample) #隨機選行 for i in randidx: curr_img = np.reshape(trainimg[i, :], (28, 28)) # 28 by 28 matrix curr_label = np.argmax(trainlabel[i, :] ) # Label plt.matshow(curr_img, cmap=plt.get_cmap('gray')) print ("" + str(i) + "th 訓練資料 " + "0標籤是 " + str(curr_label)) plt.show()
結果是:
48118th 訓練資料 標籤是 0
8268th 訓練資料 標籤是 0