深度學習 Tensorflow(五)
阿新 • • 發佈:2020-10-14
儲存模型
一、儲存整個模型
整個模型可以儲存到一個檔案當中,其中包含權重值、模型配置乃至優化器配置。這樣,就可以為模型設定檢查點,並稍後從完全相同的狀態繼續訓練,而無需訪問原始程式碼。
在 Keras 中儲存完全可正常使用的模型非常有用,可以在 TensorFlow.js 中載入他們,然後在網路瀏覽器中訓練和執行他們。
Keras 使用 HDF5 標準提供基本的儲存格式。
model.save('less_model_10_14.h5') # 儲存模型,h5 格式
# 使用儲存的模型
new_model = tf.keras.models.load_model('less_model_10_14.h5 ')
二、儲存模型架構
儲存模型架構,模型的層數設定,不儲存權重和優化器設定
json_config = model.to_json()
# 模型恢復,重建
reinitialized_model = tf.keras.models.model_from_json(json_config)
reinitialized_model.summary()
# 重建的模型沒有經過配置,權重是隨機的,使用時需要配置優化器
reinitialized_model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy ',
metrics=['acc']
)
三、儲存模型權重
儲存模型的狀態(權重值),可以通過 get_weights() 獲取權值,通過 set_weights() 設定權重值
weights = model.get_weights()
reinitialized_model.set_weights(weights)
reinitialized_model.evaluate(test_image, test_label, verbose=0)
# 儲存權重到本地檔案
model.save_weights('less_weights_10_14.h5 ')
# 載入權重
reinitialized_model.load_weights('less_weights_10_14.h5')
reinitialized_model.evaluate(test_image, test_label, verbose=0)
四、在訓練期間儲存檢查點
在訓練期間或尋來結束自動儲存檢查點,這樣可以使用經過訓練的模型,無需重新訓練該模型,或從上次暫停的地方繼續訓練,以防訓練過程中斷。
回撥函式:tf.keras.callbacks.ModelCheckpoint
checkpoint_path = 'training/check_point_10_14.ckpt'
cp_callback = tf.keras.callbacks.ModelCheckpoint(checkpoint_path, save_weights_only=True) # 僅僅儲存了權重值,如果儲存整個模型,呼叫方式同上
model.fit(train_image, train_label, epochs=3, callbacks=[cp_callback])
# 當網路重新開始時呼叫檢查點
model.load_weights(checkpoint_path)
五、自定義訓練中儲存檢查點
cp_dir = './customtrain'
cp_prefix = os.path.join(cp_dir, 'ckpt') # 新增檔案字首
checkpoint = tf.train.Checkpoint(optimizer = optimizer, model = model)
def train():
for epoch in range(5):
for (batch, (images, labels)) in enumerate(dataset):
train_step(model, images, labels)
print('Epoch{} loss is {}'.format(epoch, train_loss.result()))
print('Epoch{} accuracy is {}'.format(epoch, train_accuracy.result()))
train_loss.reset_states()
train_accuracy.reset_states()
if (epoch + 1) % 2 == 0: # 儲存的頻率
checkpoint.save(file_prefix = cp_prefix)
# 恢復模型
tf.train.latest_checkpoint(cp_dir) # 最新的檢查點
checkpoint.restore(tf.train.latest_checkpoint(cp_dir)) # 通過模型的 name 屬性對應來進行恢復