1. 程式人生 > >使用TensorFlow直接獲取處理MNIST資料

使用TensorFlow直接獲取處理MNIST資料

MNIST是一個非常有名的手寫體數字識別資料集,TensorFlow對MNIST資料集做了封裝,可以直接呼叫。MNIST資料集包含了60000張圖片作為訓練資料,10000張圖片作為測試資料,每一張圖片都代表了0-9中的一個數字,圖片大小都是28*28。雖然這個資料集只提供了訓練和測試資料,但是為了驗證訓練網路的效果,一般從訓練資料中劃分出一部分資料作為驗證資料,測試神經網路模型在不同引數下的效果。TensorFlow提供了一個類來處理MNIST資料。程式碼如下:

from tensorflow.examples.tutorials.mnist import input_data

#載入MNIST資料集,如果指定地址下沒有下載好的資料,那麼TensorFlow會自動在網站上下載資料
mnist = input_data.read_data_sets("/tensorflow_google") #列印訓練資料大小 print("Training data size:", mnist.train.num_examples) #列印驗證集大小 print("Validating data size:", mnist.validation.num_examples) #列印測試集大小 print("Testing data size:", mnist.test.num_examples) #列印訓練樣例 print("Example training data", mnist.train
.images[0]) #列印訓練樣例的標籤 print("Example training data label:", mnist.train.labels[0]) >>Training data size: 55000 Validating data size: 5000 Testing data size: 10000 Example training data [ 0. ... 0. ] Example training data label: 7

處理後的每一張圖片是一個長度為784(28*28)的一維陣列,陣列中的資料為圖片的畫素,畫素元素取值範圍為0-1,代表了顏色的深淺,其中0為白色,1為黑色。為了可以使用隨機梯度下降,input_data.read_data_sets生成的類還提供了mnist.train.next_batch,可以從素有的訓練資料中讀取一小部分作為一個訓練batch,例如:

batch_size = 200
xs, ys = mnist.train.next_batch(batch_size) #xs是資料,ys是對應的標籤
print("X shape", xs.shape)
print("Y shape", ys.shape)

>>X shape (200, 784)  #X是200*784的陣列
Y shape (200,)  #Y是200維的一維陣列