1. 程式人生 > 實用技巧 >【tensorflow】神經網路:斷點續訓

【tensorflow】神經網路:斷點續訓

斷點續訓,即在一次訓練結束後,可以先將得到的最優訓練引數儲存起來,待到下次訓練時,直接讀取最優引數,在此基礎上繼續訓練。

讀取模型引數:

儲存模型引數的檔案格式為 ckpt(checkpoint)。

生成 ckpt 檔案時,會同步生成索引表,所以可通過判斷是否存在索引表來判斷是否存在模型引數。

# 模型引數儲存路徑
checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt"  
if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path)

儲存模型引數:

# 定義回撥函式,在模型訓練時,回撥函式會被執行,完成保留引數操作
cp_callback = tf.keras.callbacks.ModelCheckpoint(
  # 檔案儲存路徑
  filepath=checkpoint_save_path,

  # 是否只保留模型引數
  save_weights_only=True,

  # 是否只保留最優結果
  save_best_only=True
)

# 執行訓練過程,儲存新的訓練引數
history = model.fit(x_train, y_train,
            batch_size=32, epochs=5,
            validation_data
=(x_test, y_test),             validation_freq=1,             callbacks=[cp_callback])

程式碼:

import tensorflow as tf
import os

# 讀取輸入特徵和標籤
mnist = tf.keras.datasets.mnist
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# 資料歸一化,減小計算量,方便神經網路吸收
x_train, x_test = x_train/255.0, x_test/255.0

#
宣告網路結構 model = tf.keras.models.Sequential([ tf.keras.layers.Flatten(), tf.keras.layers.Dense(128, activation="relu"), tf.keras.layers.Dense(10, activation="softmax") ]) # 配置訓練方法 model.compile(optimizer="adam", loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=[tf.keras.metrics.sparse_categorical_accuracy]) # 如果存在引數檔案,直接讀取,在此基礎上繼續訓練 checkpoint_save_path = "class4/MNIST_FC/checkpoint/mnist.ckpt" # 模型引數儲存路徑 if os.path.exists(checkpoint_save_path + ".index"): model.load_weights(checkpoint_save_path) # 定義回撥函式,在模型訓練時,完成保留引數操作 cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_save_path, save_weights_only=True, save_best_only=True) # 執行訓練過程,儲存新的訓練引數 history = model.fit(x_train, y_train, batch_size=32, epochs=5, validation_data=(x_test, y_test), validation_freq=1, callbacks=[cp_callback]) # 列印網路結構和引數 model.summary()