Tensorflow mnist資料集操作
import sys import numpy as np import matplotlib.pyplot as plt import tensorflow as tf import argparse import matplotlib.image as mpimg from skimage import io #利用skilt-image 將圖片讀成numpy陣列 from tensorflow.examples.tutorials.mnist import input_data #引入tensorflow的mnist手寫數字識別庫 class Mnister: #定義一個類 def _init_(self): pass def learn_mnist(self): mnist = input_data.read_data_sets('datasets',one_hot = True) #tensorflow的input_data和read_data_sets,第一個是資料集存放路徑,第二個引數是標籤集的格式 #one_hot獨熱編碼,N個維度來對N個類別進行編碼,並且對於每個類別,只有一個維度有效,記作數字1 ;其它維度均記作數字0. x_train_data = mnist.train.images #取出輸入訓練訊號資料X_train,為矩陣形式,一行代表一個樣本有784維(28*28) y_train_label = mnist.train.labels #取出輸入訓練訊號標籤y_train,為one_hot為行的矩陣形式,每一行代表對應樣本的正確結果,在這是10維(0—9) x_validation_data = mnist.validation.images #取出驗證集的資料集 y_validation_label = mnist.validation.labels #取出驗證集的資料集對應的標籤(10維0-9) x_test_data = mnist.test.images #取出測試集的資料集 y_test_label = mnist.test.labels #取出測試集的資料集對應的標籤 print('x_train_data:{0} y_train_label:{1}'.format(x_train_data.shape,y_train_label.shape)) print(' x_validation_data:{0} y_validation_label:{1}'.format( x_validation_data.shape,y_validation_label.shape)) print('x_test_data:{0} y_test_label:{1}'.format(x_test_data.shape,y_test_label.shape)) image_raw = (x_train_data[1]*255).astype(int)#將第二個資料拿出,由0—1的浮點數,轉化為0-255的整數灰度值。 image = image_raw.reshape(28,28) #將784維的行向量轉化為28*28的矩陣 label = y_train_label[1] #讀取該樣本的正確結果標籤 idx = 0 #定義索引 for item in label: if 1==item: break #標籤向量one-hot此元素為1時終止迴圈, idx += 1 plt.title('digit:{0}'.format(idx))#將樣本標籤的正確結果顯示在圖片上 plt.imshow(image,cmap='gray') #以灰度來顯示圖象 plt.show() #顯示圖片 def main(self): mnister = Mnister() mnister.learn_mnist() if '__main__' == __name__: parser = argparse.ArgumentParser() parser.add_argument('--datda_dir',type = str,default='datasets', help = 'Directory for storing input data') FLAGS,unparsed = parser.parse_known_args() tf.app.run(main = main,argv = [sys.argv[0]]+unparsed)