1. 程式人生 > 其它 >使用tf.data資料轉換來訓練MNIST資料集

使用tf.data資料轉換來訓練MNIST資料集

技術標籤:TensorFlow神經網路和深度學習tensorflow神經網路深度學習

以MNIST資料集為例來訓練模型

# -*- coding: UTF-8 -*-
"""
Author: LGD
FileName: fashion_mnist_tfdataset
DateTime: 2020/11/26 09:04 
SoftWare: PyCharm
"""
import tensorflow as tf

print('Tensorflow version: {}'.format(tf.__version__))

(train_images,
train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data() # 資料歸一化 train_images = train_images / 255 test_images = test_images / 255 # 建立train_images的Dataset ds_train_img = tf.data.Dataset.from_tensor_slices(train_images) print(ds_train_img) ds_train_label = tf.data.Dataset.from_tensor_slices(
train_labels) print(ds_train_label) # 使用zip將資料合併到一起 ds_train = tf.data.Dataset.zip((ds_train_img, ds_train_label)) print(ds_train) # 對資料做變換,取出10000組資料亂序,迴圈,分批次,每批次資料量為64 ds_train = ds_train.shuffle(10000).repeat().batch(64) # 建立模型 model = tf.keras.Sequential([ tf.keras.layers.Flatten(input_shape=(28,
28)), tf.keras.layers.Dense(128, activation='relu'), tf.keras.layers.Dense(10, activation='softmax') ]) # 編譯模型 model.compile( optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'] ) # # 訓練 # steps_per_epochs = train_images.shape[0] // 64 # 每次迭代64張圖片,每個epoch迭代的步數 # model.fit( # ds_train, # epochs=5, # steps_per_epoch=steps_per_epochs # ) # 建立test_images的Dataset ds_test = tf.data.Dataset.from_tensor_slices((test_images, test_labels)) ds_test = ds_test.batch(64) # 訓練 steps_per_epochs = train_images.shape[0] // 64 # 每次迭代64張圖片,每個epoch迭代的步數 model.fit( ds_train, epochs=5, steps_per_epoch=steps_per_epochs, validation_data=ds_test, validation_steps=10000//64 # 由於有迴圈,必須要有step它才知道什麼時候列印一下驗證準確率。 )

獲取MNIST資料集,可以直接是在程式碼載入裡下載,也可以關注下列公眾號加讀者微信,分享百度網盤連結。
在這裡插入圖片描述