TensorFlow學習筆記--MNIST全連接模型實踐
阿新 • 發佈:2021-12-10
"""Train a fully-connected classifier on MNIST with tf.keras.

Flow: load and scale MNIST, build ``MnistModel``, restore weights from
``checkpoint_save_path`` if a previous run saved them, train for 5 epochs
while checkpointing the best weights, then print the model summary.
Optionally dump every trainable variable to a text file for inspection.
"""
import os
import sys

import numpy as np
import tensorflow as tf
# Use the public tf.keras API throughout. ``tensorflow.python.keras`` is a
# private, deprecated namespace; mixing its classes with the public
# ``tf.keras.optimizers`` / ``tf.keras.callbacks`` objects below can fail on
# TF >= 2.6, where tf.keras is backed by the separate ``keras`` package.
from tensorflow.keras import Model
from tensorflow.keras.datasets import mnist
from tensorflow.keras.layers import Dense, Flatten

(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Scale raw pixel values (0..255) into [0, 1].
x_train, x_test = x_train / 255.0, x_test / 255.0

# Prefix of the TensorFlow weights checkpoint written by ModelCheckpoint.
# NOTE(review): this prefix-style path assumes Keras 2 (TF <= 2.15); under
# Keras 3, save_weights_only=True requires a path ending in '.weights.h5'.
checkpoint_save_path = './checkpoint/model.ckpt'

# Set to True to also dump every trainable variable as text after training.
SAVE_WEIGHTS_TXT = False
WEIGHTS_TXT_PATH = './weight.txt'


class MnistModel(Model):
    """Flatten -> Dense(128, relu) -> Dense(10, softmax)."""

    def __init__(self):
        super().__init__()
        self.flatten = Flatten()
        self.dense1 = Dense(128, activation='relu')
        self.dense2 = Dense(10, activation='softmax')

    def call(self, x):
        """Return the 10-way softmax output for the input batch *x*."""
        x = self.flatten(x)
        x = self.dense1(x)
        return self.dense2(x)


def save_weights_as_text(model, path):
    """Write name, shape and full values of each trainable variable to *path*.

    ``str(ndarray)`` elides arrays larger than NumPy's print threshold
    (1000 elements by default) with ``...``, so the threshold is lifted for
    the duration of the dump to get the complete values.
    """
    with open(path, 'w', encoding='utf-8') as f, \
            np.printoptions(threshold=sys.maxsize):
        for var in model.trainable_variables:
            f.write(f'{var.name}\n')
            f.write(f'{var.shape}\n')
            f.write(f'{var.numpy()}\n')


model = MnistModel()

# The output layer already applies softmax, so the loss takes probabilities
# (from_logits defaults to False).
model.compile(optimizer=tf.keras.optimizers.Adam(),
              loss=tf.keras.losses.sparse_categorical_crossentropy,
              metrics=['sparse_categorical_accuracy'])

# Save weights only, and only when the monitored metric improves
# (ModelCheckpoint monitors 'val_loss' by default, available because
# validation_data is passed to fit below).
model_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_save_path,
    save_weights_only=True,
    save_best_only=True)

# A TF-format checkpoint writes '<prefix>.index' alongside its data files;
# if it exists, resume from the previously saved weights.
if os.path.exists(checkpoint_save_path + '.index'):
    model.load_weights(checkpoint_save_path)

model.fit(x=x_train, y=y_train,
          batch_size=32,
          epochs=5,
          validation_data=(x_test, y_test),
          callbacks=[model_callback])

model.summary()

if SAVE_WEIGHTS_TXT:
    save_weights_as_text(model, WEIGHTS_TXT_PATH)