【tensorflow】神經網路:斷點續訓
阿新 • • 發佈:2020-08-21
斷點續訓,即在一次訓練結束後,可以先將得到的最優訓練引數儲存起來,待到下次訓練時,直接讀取最優引數,在此基礎上繼續訓練。
讀取模型引數:
儲存模型引數的檔案格式為 ckpt(checkpoint)。
生成 ckpt 檔案時,會同步生成索引表,所以可通過判斷是否存在索引表來判斷是否存在模型引數。
# 模型引數儲存路徑 checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)
儲存模型引數:
# 定義回撥函式,在模型訓練時,回撥函式會被執行,完成保留引數操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(
# 檔案儲存路徑
filepath=checkpoint_save_path,
# 是否只保留模型引數
save_weights_only=True,
# 是否只保留最優結果
save_best_only=True
)
# 執行訓練過程,儲存新的訓練引數
history = model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data =(x_test, y_test),
validation_freq=1,
callbacks=[cp_callback])
程式碼:
import tensorflow as tf
import os
# 讀取輸入特徵和標籤
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# 資料歸一化,減小計算量,方便神經網路吸收
x_train, x_test = x_train/255.0, x_test/255.0
# 宣告網路結構
model = tf.keras.models.Sequential([
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dense(10, activation="softmax")
])
# 配置訓練方法
model.compile(optimizer="adam",
loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
metrics=[tf.keras.metrics.sparse_categorical_accuracy])
# 如果存在引數檔案,直接讀取,在此基礎上繼續訓練
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt" # 模型引數儲存路徑
if os.path.exists(checkpoint_save_path + ".index"):
model.load_weights(checkpoint_save_path)
# 定義回撥函式,在模型訓練時,完成保留引數操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path,
save_weights_only=True,
save_best_only=True)
# 執行訓練過程,儲存新的訓練引數
history = model.fit(x_train, y_train,
batch_size=32, epochs=5,
validation_data=(x_test, y_test),
validation_freq=1,
callbacks=[cp_callback])
# 列印網路結構和引數
model.summary()