1. 程式人生 > 其它 > TensorFlow 學習筆記 -- MNIST 全連接模型實踐

TensorFlow 學習筆記 -- MNIST 全連接模型實踐

import os
from tensorflow.keras.datasets import mnist
import tensorflow as tf
from tensorflow.python.keras import Model
from tensorflow.python.keras.layers import Flatten, Dense

# Load the MNIST train and test splits through the Keras dataset helper.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Rescale the raw 0-255 pixel intensities into the [0, 1] range.
x_train = x_train / 255.0
x_test = x_test / 255.0

# Prefix of the TensorFlow-format weight checkpoint. Saving writes a
# '<prefix>.index' file, which is what the resume check below looks for.
checkpoint_save_path = './checkpoint/model.ckpt'


class MnistModel(tf.keras.Model):
    """Fully connected MNIST classifier: Flatten -> Dense(128, relu) -> Dense(10, softmax).

    Built entirely on the public ``tf.keras`` API, so the model, optimizer,
    loss and callbacks all come from the same Keras implementation. The
    ``tensorflow.python.keras`` path is TensorFlow-internal and, since TF 2.6,
    a separate legacy copy; mixing it with ``tf.keras`` optimizers can fail
    inside ``compile()``.
    """

    def __init__(self):
        super().__init__()
        # Attribute names are kept stable: object-based TF checkpoints key
        # variables by them, so renaming would orphan saved weights.
        self.flatten = tf.keras.layers.Flatten()
        self.dense1 = tf.keras.layers.Dense(128, activation='relu')
        self.dense2 = tf.keras.layers.Dense(10, activation='softmax')

    def call(self, x):
        """Return softmax probabilities over the 10 classes for the input batch ``x``."""
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)


def dump_weights_to_text(model, path='./weight.txt'):
    """Write each trainable variable's name, shape and full values to ``path``.

    Opt-in debugging aid; it is not called by default. ``str()`` of a large
    NumPy array is abbreviated with '...' once it exceeds NumPy's print
    threshold, which is presumably the problem the original write of
    ``str(i.numpy())`` ran into. ``np.array2string(..., threshold=sys.maxsize)``
    prints every value instead.
    """
    import sys

    import numpy as np

    with open(path, 'w', encoding='utf-8') as f:
        for var in model.trainable_variables:
            f.write(str(var.name) + '\n')
            f.write(str(var.shape) + '\n')
            f.write(np.array2string(var.numpy(), threshold=sys.maxsize) + '\n')


model = MnistModel()

# Adam optimizer. The sparse_* loss and metric take integer class labels
# rather than one-hot vectors, and the loss consumes the softmax output
# directly (probabilities, not logits).
model.compile(
    optimizer=tf.keras.optimizers.Adam(),
    loss=tf.keras.losses.sparse_categorical_crossentropy,
    metrics=['sparse_categorical_accuracy'],
)

# Save only the weights, and only when the monitored validation loss improves.
model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    monitor='val_loss',
    save_weights_only=True,
    save_best_only=True,
)

# Resume from an earlier run if a checkpoint has already been written.
if os.path.exists(checkpoint_save_path + '.index'):
    model.load_weights(checkpoint_save_path)

model.fit(
    x=x_train,
    y=y_train,
    batch_size=32,
    epochs=5,
    validation_data=(x_test, y_test),
    callbacks=[model_callback],
)

# Layer/parameter overview. summary() needs a built model, and a subclassed
# model is only built once it has been called on data (here, during fit()).
model.summary()

# dump_weights_to_text(model)  # uncomment to export trained weights as text