TensorFlow2.0(12):模型儲存與序列化
阿新 • • 發佈:2019-12-24
注:本系列所有部落格將持續更新併發布在github上,您可以通過github下載本系列所有文章筆記檔案。
模型訓練好之後,我們就要想辦法將其持久化儲存下來,不然關機或者程式退出後模型就不復存在了。本文介紹兩種持久化儲存模型的方法:
在介紹這兩種方法之前,我們得先建立並訓練好一個模型,還是以mnist手寫數字識別資料集訓練模型為例:
In [1]:import tensorflow as tf from tensorflow import keras from tensorflow.keras import layers, optimizers, Sequential
model = Sequential([ # 建立模型 layers.Dense(256, activation=tf.nn.relu), layers.Dense(128, activation=tf.nn.relu), layers.Dense(64, activation=tf.nn.relu), layers.Dense(32, activation=tf.nn.relu), layers.Dense(10) ] ) (x_train, y_train), (x_test, y_test) = keras.datasets.mnist.load_data() x_train = x_train.reshape(60000, 784).astype('float32') / 255 x_test = x_test.reshape(10000, 784).astype('float32') / 255 model.compile(loss='sparse_categorical_crossentropy', optimizer=keras.optimizers.RMSprop()) history = model.fit(x_train, y_train, # 進行簡單的1次迭代訓練 batch_size=64, epochs=1)
Train on 60000 samples 60000/60000 [==============================] - 3s 46us/sample - loss: 2.3700
方法一:model.save()¶
通過模型自帶的save()方法可以將模型儲存到一個指定檔案中,儲存的內容包括:
- 模型的結構
- 模型的權重引數
- 通過compile()方法配置的模型訓練引數
- 優化器及其狀態
model.save('mymodels/mnist.h5')
使用save()方法儲存後,在mymodels目錄下就會有一個mnist.h5檔案。需要使用模型時,通過keras.models.load_model()方法從檔案中再次載入即可。
new_model = keras.models.load_model('mymodels/mnist.h5')
WARNING:tensorflow:Sequential models without an `input_shape` passed to the first layer cannot reload their optimizer state. As a result, your model isstarting with a freshly initialized optimizer.
新加載出來的new_model在結構、功能、引數各方面與model是一樣的。
通過save()方法,也可以將模型儲存為SavedModel 格式。SavedModel格式是TensorFlow所特有的一種序列化檔案格式,其他程式語言實現的TensorFlow中同樣支援:
In [5]:model.save('mymodels/mnist_model', save_format='tf') # 將模型儲存為SaveModel格式
WARNING:tensorflow:From /home/chb/anaconda3/envs/study_python/lib/python3.7/site-packages/tensorflow_core/python/ops/resource_variable_ops.py:1781: calling BaseResourceVariable.__init__ (from tensorflow.python.ops.resource_variable_ops) with constraint is deprecated and will be removed in a future version. Instructions for updating: If using Keras pass *_constraint arguments to layers. INFO:tensorflow:Assets written to: mymodels/mnist_model/assetsIn [6]:
new_model = keras.models.load_model('mymodels/mnist_model') # 載入模型
方法二:model.save_weights()¶
save()方法會保留模型的所有資訊,但有時候,我們僅對部分資訊感興趣,例如僅對模型的權重引數感興趣,那麼就可以通過save_weights()方法進行儲存。
In [14]:model.save_weights('mymodels/mnits_weights') # 儲存模型權重資訊In [15]:
new_model = Sequential([ # 建立新的模型 layers.Dense(256, activation=tf.nn.relu), layers.Dense(128, activation=tf.nn.relu), layers.Dense(64, activation=tf.nn.relu), layers.Dense(32, activation=tf.nn.relu), layers.Dense(10) ] ) new_model.compile(loss='sparse_categorical_crossentropy', optimizer=keras.optimizers.RMSprop()) new_model.load_weights('mymodels/mnits_weights') # 將儲存好的權重資訊載入的新的模型中Out[15]:
<tensorflow.python.training.tracking.util.CheckpointLoadStatus at 0x7f49c42b87d0><