TensorFlow Estimator 教程之----Checkpoints
本文介紹了 Estimators 模型的儲存和恢復。
TensorFlow提供了兩種模型格式:
- checkpoints:一種與語言相關的序列化格式。
- SavedModel:一種獨立於語言且可恢復的序列化格式。
本文主要講述checkpoints相關內容。關於 SavedModel
的更多細節,詳見 Saving and Restoring。
1. Estimator 模型的儲存
Estimators 在訓練過程中會自動將以下內容儲存到硬碟:
- chenkpoints:模型快照。
- event files:用於TensorBoard模型訓練過程視覺化。
通過 model_dir
引數,我們可以指定 Estimator 儲存上述檔案時的目錄。
# 例項化 estimator
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
# 訓練 estimator
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
如下圖所示,第一次呼叫 train
model_dir
目錄。
在類Unix系統中,可以使用 ls
命令來檢視 model_dir
檔案中儲存的內容。
$ ls -l models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta
通過 ls
命令,我們可以看到,Estimator在step 1(訓練的開始)和step 200(訓練的結尾)儲存了checkpoints 檔案。
1.1 model_dir 的預設目錄
如果你沒有指定 model_dir
引數,Estimator 會將checkpoints 檔案儲存到一個臨時資料夾(基於Python的 tempfile.mkdtemp 函式)。
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3)
print(classifier.model_dir) # 臨時資料夾
1.2 Checkpoints 檔案的儲存頻率
預設情況下,Estimator 會基於以下原則來儲存 checkpoints 檔案。
- 每10分鐘(600秒)儲存一次 checkpoint。
- 在訓練的開始和結尾各儲存一次checkpoint。
- 只儲存最近5個checkpoint。
上面是Checkpoint的預設儲存策略,你可以通過如下步驟來修改該策略:
- 在例項化 Estimator 時,將 RunConfig 物件傳給 config 引數。
my_checkpointing_config = tf.estimator.RunConfig(
save_checkpoints_secs = 20*60, # 每20分鐘儲存一次 checkpoints
keep_checkpoint_max = 10, # 保留最新的10個checkpoints
)
classifier = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris',
config=my_checkpointing_config)
2. 恢復你的模型
第一次呼叫 Estimator 的 train
方法時,TensorFlow會儲存checkpoint檔案到 model_dir 目錄。隨後呼叫 tarin
、evaluate
、predict
方法將進行如下操作:
- Estimator 通過執行
model_fn
來構建模型的計算圖。 - Estimator 從儲存的 checkpoints 中初始化模型引數。
2.1 避免restore錯誤
僅在模型和checkpoint相容的情況下,才能從checkpoint恢復模型的狀態。例如,假設您訓練了DNNClassifier包含兩個隱藏層的 Estimator,每個隱藏層有10個節點:
classifier = tf.estimator.DNNClassifier(
feature_columns=feature_columns,
hidden_units=[10, 10],
n_classes=3,
model_dir='models/iris')
classifier.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
訓練之後,如果您將每個隱藏層中的神經元數量從10更改為20,並嘗試重新訓練模型:
classifier2 = tf.estimator.DNNClassifier(
feature_columns=my_feature_columns,
hidden_units=[20, 20], # Change the number of neurons in the model.
n_classes=3,
model_dir='models/iris')
classifier2.train(
input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
steps=200)
由於checkpoint中的狀態與描述的模型不相容,因此classifier2重新訓練失敗並出現以下錯誤:
...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]
總結
Estimator 對於模型的儲存和恢復有著良好的支援。
- 使用TF的低階API來儲存恢復模型。
- 使用SavedModel格式(獨立於程式語言)來儲存、恢復模型。