1. 程式人生 > 實用技巧 >【tensorflow】搭建 手寫數字識別 神經網路模型:Sequential() / 神經網路類class 兩種方法

【tensorflow】搭建 手寫數字識別 神經網路模型:Sequential() / 神經網路類class 兩種方法

MNIST 資料集一共有 7 萬張圖片,都是28x28 畫素點的 0~9 手寫數字,其中6 萬用於訓練,1 萬張用於測試。

f.keras + Sequential() 6 步搭建神經網路

import tensorflow as tf

# 讀入訓練所需的輸入特徵和標籤
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 輸入特徵歸一化,減小計算量,方便神經網路吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 搭建網路
model = tf.keras.models.Sequential([ # 將輸入特徵(28x28)拉直為一維陣列(1x748) tf.keras.layers.Flatten(), # 定義第一層網路,有128個神經元 tf.keras.layers.Dense(128, activation="relu"), # 定義第二層網路,有10個神經元 tf.keras.layers.Dense(10, activation="softmax") ]) # 配置訓練方法 model.compile(optimizer=tf.keras.optimizers.Adam(), loss
=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) # 執行訓練過程 model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1) # 打印出網路結構和引數統計 model.summary()

tf.keras + 神經網路類class 6 步搭建神經網路

import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import Model

# 讀取訓練用的輸入特徵和標籤
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 輸入特徵歸一化,減小計算量,方便神經網路吸收
x_train, x_test = x_train/255.0, x_test/255.0

# 定義神經網路類
class MnistModel(Model):
    def __init__(self):
        super(MnistModel, self).__init__()
        # 定義拉直層
        self.flatten = Flatten()
        # 定義第一層神經網路
        self.d1 = Dense(128, activation="relu")
        # 定義第二層神經網路
        self.d2 = Dense(10, activation="softmax")

    def call(self, x):
        # 將輸入特徵拉直成一維陣列
        x = self.flatten(x)
        # 呼叫剩下兩層神經網路,實現前向傳播
        x = self.d1(x)
        y = self.d2(x)
        return y

# 宣告神經網路物件
model = MnistModel()

# 配置訓練方法
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              metrics=[tf.keras.metrics.sparse_categorical_accuracy])

# 執行訓練過程
model.fit(x_train, y_train,
          batch_size=32, epochs=5,
          validation_data=(x_test, y_test),
          validation_freq=1)

# 列印網路結構和引數統計
model.summary()