使用TensorFlow直接獲取處理MNIST資料
阿新 • • 發佈:2019-01-04
MNIST是一個非常有名的手寫體數字識別資料集,TensorFlow對MNIST資料集做了封裝,可以直接呼叫。MNIST資料集包含了60000張圖片作為訓練資料,10000張圖片作為測試資料,每一張圖片都代表了0-9中的一個數字,圖片大小都是28*28。雖然這個資料集只提供了訓練和測試資料,但是為了驗證訓練網路的效果,一般從訓練資料中劃分出一部分資料作為驗證資料,測試神經網路模型在不同引數下的效果。TensorFlow提供了一個類來處理MNIST資料。程式碼如下:
from tensorflow.examples.tutorials.mnist import input_data
#載入MNIST資料集,如果指定地址下沒有下載好的資料,那麼TensorFlow會自動在網站上下載資料
mnist = input_data.read_data_sets("/tensorflow_google")
#列印訓練資料大小
print("Training data size:", mnist.train.num_examples)
#列印驗證集大小
print("Validating data size:", mnist.validation.num_examples)
#列印測試集大小
print("Testing data size:", mnist.test.num_examples)
#列印訓練樣例
print("Example training data", mnist.train .images[0])
#列印訓練樣例的標籤
print("Example training data label:", mnist.train.labels[0])
>>Training data size: 55000
Validating data size: 5000
Testing data size: 10000
Example training data [ 0. ... 0. ]
Example training data label: 7
處理後的每一張圖片是一個長度為784(28*28)的一維陣列,陣列中的資料為圖片的畫素,畫素元素取值範圍為0-1,代表了顏色的深淺,其中0為白色,1為黑色。為了可以使用隨機梯度下降,input_data.read_data_sets生成的類還提供了mnist.train.next_batch,可以從素有的訓練資料中讀取一小部分作為一個訓練batch,例如:
batch_size = 200
xs, ys = mnist.train.next_batch(batch_size) #xs是資料,ys是對應的標籤
print("X shape", xs.shape)
print("Y shape", ys.shape)
>>X shape (200, 784) #X是200*784的陣列
Y shape (200,) #Y是200維的一維陣列