1. 程式人生 > >TensorFlow Estimator 教程之----Checkpoints

TensorFlow Estimator 教程之----Checkpoints

本文介紹了 Estimators 模型的儲存和恢復。

TensorFlow提供了兩種模型格式:

  • checkpoints:一種與語言相關的序列化格式。
  • SavedModel:一種獨立於語言且可恢復的序列化格式。

本文主要講述checkpoints相關內容。關於 SavedModel 的更多細節,詳見 Saving and Restoring

1. Estimator 模型的儲存

Estimators 在訓練過程中會自動將以下內容儲存到硬碟:

  • chenkpoints:模型快照。
  • event files:用於TensorBoard模型訓練過程視覺化。

通過 model_dir 引數,我們可以指定 Estimator 儲存上述檔案時的目錄。

# 例項化 estimator
classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

# 訓練 estimator
classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
    steps=200)

如下圖所示,第一次呼叫 train

方法會新增 checkpoints和其他檔案到 model_dir 目錄。 在這裡插入圖片描述 在類Unix系統中,可以使用 ls 命令來檢視 model_dir 檔案中儲存的內容。

$ ls -l models/iris
checkpoint
events.out.tfevents.timestamp.hostname
graph.pbtxt
model.ckpt-1.data-00000-of-00001
model.ckpt-1.index
model.ckpt-1.meta
model.ckpt-200.data-00000-of-00001
model.ckpt-200.index
model.ckpt-200.meta

通過 ls 命令,我們可以看到,Estimator在step 1(訓練的開始)和step 200(訓練的結尾)儲存了checkpoints 檔案。

1.1 model_dir 的預設目錄

如果你沒有指定 model_dir 引數,Estimator 會將checkpoints 檔案儲存到一個臨時資料夾(基於Python的 tempfile.mkdtemp 函式)。

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3)

print(classifier.model_dir) # 臨時資料夾

1.2 Checkpoints 檔案的儲存頻率

預設情況下,Estimator 會基於以下原則來儲存 checkpoints 檔案。

  • 每10分鐘(600秒)儲存一次 checkpoint。
  • 在訓練的開始和結尾各儲存一次checkpoint。
  • 只儲存最近5個checkpoint。

上面是Checkpoint的預設儲存策略,你可以通過如下步驟來修改該策略:

  1. 在例項化 Estimator 時,將 RunConfig 物件傳給 config 引數。
下面的程式碼將checkpoint儲存間隔設定為20分鐘,並且儲存最近的10個checkpoints:
my_checkpointing_config = tf.estimator.RunConfig(
    save_checkpoints_secs = 20*60,  # 每20分鐘儲存一次 checkpoints
    keep_checkpoint_max = 10,       # 保留最新的10個checkpoints
)

classifier = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris',
    config=my_checkpointing_config)

2. 恢復你的模型

第一次呼叫 Estimator 的 train 方法時,TensorFlow會儲存checkpoint檔案到 model_dir 目錄。隨後呼叫 tarinevaluatepredict 方法將進行如下操作:

  1. Estimator 通過執行 model_fn 來構建模型的計算圖。
  2. Estimator 從儲存的 checkpoints 中初始化模型引數。

在這裡插入圖片描述

2.1 避免restore錯誤

僅在模型和checkpoint相容的情況下,才能從checkpoint恢復模型的狀態。例如,假設您訓練了DNNClassifier包含兩個隱藏層的 Estimator,每個隱藏層有10個節點:

classifier = tf.estimator.DNNClassifier(
    feature_columns=feature_columns,
    hidden_units=[10, 10],
    n_classes=3,
    model_dir='models/iris')

classifier.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

訓練之後,如果您將每個隱藏層中的神經元數量從10更改為20,並嘗試重新訓練模型:

classifier2 = tf.estimator.DNNClassifier(
    feature_columns=my_feature_columns,
    hidden_units=[20, 20],  # Change the number of neurons in the model.
    n_classes=3,
    model_dir='models/iris')

classifier2.train(
    input_fn=lambda:train_input_fn(train_x, train_y, batch_size=100),
        steps=200)

由於checkpoint中的狀態與描述的模型不相容,因此classifier2重新訓練失敗並出現以下錯誤:

...
InvalidArgumentError (see above for traceback): tensor_name =
dnn/hiddenlayer_1/bias/t_0/Adagrad; shape in shape_and_slice spec [10]
does not match the shape stored in checkpoint: [20]

總結

Estimator 對於模型的儲存和恢復有著良好的支援。

  • 使用TF的低階API來儲存恢復模型。
  • 使用SavedModel格式(獨立於程式語言)來儲存、恢復模型。

本文的程式碼來源