1. 程式人生 > 實用技巧 >深度學習 Tensorflow(五)

深度學習 Tensorflow(五)

儲存模型

一、儲存整個模型

整個模型可以儲存到一個檔案當中,其中包含權重值、模型配置乃至優化器配置。這樣,就可以為模型設定檢查點,並稍後從完全相同的狀態繼續訓練,而無需訪問原始程式碼。

在 Keras 中儲存完全可正常使用的模型非常有用,可以在 TensorFlow.js 中載入他們,然後在網路瀏覽器中訓練和執行他們。

Keras 使用 HDF5 標準提供基本的儲存格式。

model.save('less_model_10_14.h5')    # 儲存模型,h5 格式
# 使用儲存的模型
new_model = tf.keras.models.load_model('less_model_10_14.h5
')

二、儲存模型架構

儲存模型架構,模型的層數設定,不儲存權重和優化器設定

json_config = model.to_json()
# 模型恢復,重建
reinitialized_model = tf.keras.models.model_from_json(json_config)
reinitialized_model.summary()
# 重建的模型沒有經過配置,權重是隨機的,使用時需要配置優化器
reinitialized_model.compile(optimizer='adam',
                            loss='sparse_categorical_crossentropy
', metrics=['acc'] )

三、儲存模型權重

儲存模型的狀態(權重值),可以通過 get_weights() 獲取權值,通過 set_weights() 設定權重值

weights = model.get_weights()
reinitialized_model.set_weights(weights)
reinitialized_model.evaluate(test_image, test_label, verbose=0)

# 儲存權重到本地檔案
model.save_weights('less_weights_10_14.h5
') # 載入權重 reinitialized_model.load_weights('less_weights_10_14.h5') reinitialized_model.evaluate(test_image, test_label, verbose=0)

四、在訓練期間儲存檢查點

在訓練期間或尋來結束自動儲存檢查點,這樣可以使用經過訓練的模型,無需重新訓練該模型,或從上次暫停的地方繼續訓練,以防訓練過程中斷。

回撥函式:tf.keras.callbacks.ModelCheckpoint

checkpoint_path = 'training/check_point_10_14.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True)    # 僅僅儲存了權重值,如果儲存整個模型,呼叫方式同上

model.fit(train_image, train_label, epochs=3, callbacks=[cp_callback])

# 當網路重新開始時呼叫檢查點
model.load_weights(checkpoint_path)

五、自定義訓練中儲存檢查點

cp_dir = './customtrain'
cp_prefix = os.path.join(cp_dir, 'ckpt')    # 新增檔案字首
checkpoint = tf.train.Checkpoint(optimizer = optimizer, model = model)
def train():
    for epoch in range(5):
        for (batch, (images, labels)) in enumerate(dataset):
            train_step(model, images, labels)
        print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
        print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))
        train_loss.reset_states()
        train_accuracy.reset_states()
        if (epoch + 1) % 2 == 0:    # 儲存的頻率
            checkpoint.save(file_prefix = cp_prefix)

# 恢復模型
tf.train.latest_checkpoint(cp_dir)    # 最新的檢查點
checkpoint.restore(tf.train.latest_checkpoint(cp_dir))    # 通過模型的 name 屬性對應來進行恢復