1. 程式人生 > 實用技巧 >tensrflow2.0——模型保存於讀取

tensrflow2.0——模型保存於讀取

一、說明

  我們建立好模型之後需要儲存模型,以方便後續對模型的讀取與呼叫,儲存模型我們可能有下面三種需求:1、只儲存模型權重引數;2、同時儲存模型圖結構與權重引數;3、在訓練過程的檢查點儲存模型資料。下面分別對這三種需求進行實現。

二、僅儲存模型引數

  僅儲存模型引數可以用一下的API:

  Model.save_weights(file_path)  # 將檔案儲存到save_path
  Model.load_weights(file_path)  # 將檔案讀取到save_path

  注意:由於save_weights只是儲存權重w、b的引數值,所以在載入時最好保證我們的模型結構和原來儲存的模型結構是相同的,否則可能會報錯。.

  模型在儲存之後會有多個檔案:

  • index型別檔案,在分散式計算中,索引檔案會指示哪些權重儲存在哪個分片。
  • checkpoint型別檔案,檢查檔案點包含:一個或多個包含模型權重的分片
  • 如果您只在一臺機器上訓練模型,那麼您將有一個帶有後綴的分片:.data-00000-of-00001
import tensorflow as tf
import os

# 讀取資料集
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.fashion_mnist.load_data()

# 資料集歸一化
train_images = train_images / 255
train_labels 
= train_labels / 255 # 進行資料的歸一化,加快計算的程序 # 建立模型結構 net_input=tf.keras.Input(shape=(28,28)) fl=tf.keras.layers.Flatten()(net_input)#呼叫input l1=tf.keras.layers.Dense(32,activation="relu")(fl) l2=tf.keras.layers.Dropout(0.5)(l1) net_output=tf.keras.layers.Dense(10,activation="softmax")(l2) # 建立模型類 model = tf.keras.Model(inputs=net_input, outputs=net_output)
# 檢視模型的結構 model.summary() # 模型編譯 model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001), loss="sparse_categorical_crossentropy", metrics=['acc']) # 模型訓練 model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1) # 模型存放路徑 save_path = './save_weights/' model.save_weights(save_path) # 模型載入 model.load_weights(save_path) # # 定義一個與原模型結構不同的模型 # net_in=tf.keras.Input(shape=(748,)) # net_out=tf.keras.layers.Dense(10,activation="softmax")(net_in) # # # 用不同結構的模型讀取引數,這裡會報錯 # model2=tf.keras.Model(inputs=net_in,outputs=net_out) # model2.load_weights(save_path)

三、同時儲存結構與引數

  Keras使用HDF5標準提供基本儲存格式,出於我們的目的,可以將儲存的模型視為單個二進位制blob。

  儲存完整的模型非常有用,使我們可以在TensorFlow.js(HDF5,Saved Model) 中載入它們,然後在Web瀏覽器中訓練和執行它們,或者使用TensorFlow Lite(HDF5,Saved Model)將它們轉換為在移動裝置上執行。

# 模型訓練
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1)

# 儲存模型
model.save('net_model.h5')

# 模型載入
new_model=tf.keras.models.load_model('net_model.h5')

四、在訓練過程的檢查點儲存模型資料

  在訓練過程的檢查點儲存模型資料有兩個作用:1、我們可以儲存訓練各個節點的資料,便於我們把訓練效果最好的節點的模型挑選出來。2、可以隨時先暫停訓練模型,當想要訓練時繼續訓練。

  在訓練的檢查點儲存模型需要用到tf.keras.callbacks.ModelCheckpoint()類,這個是一個回撥類,可以以列表形式傳入到fit()方法的callbacks引數中。

  回撥中類,檔名以.ckpt作為字尾,如檔案路徑'./checkpoint/train.ckpt',會在checkpoint生成三個檔案,字尾與Model.save_weights()方法建立的檔案字尾相同,意義也相同。以下為回撥類的引數:

  tf.keras.callbacks.ModelCheckpoint()

  • filepath:string,儲存模型檔案的路徑。
  • monitor:監控:要監控的數量。
  • verbose詳細:詳細模式,0或1。
  • save_best_only:如果save_best_only = True,則不會覆蓋根據監控數量的最新最佳模型。
  • save_weights_only:如果為True,則只有模型的權重
    儲存(model.save_weights(filepath)),否則儲存完整模型(model.save(filepath))。
  • mode:{auto,min,max}之一。 如果save_best_only =
    True,則根據監控數量的最大化或最小化來決定覆蓋當前儲存檔案。
    對於val_acc,這應該是max,對於val_loss,這應該是min等。在自動模式下,從監控量的名稱自動推斷方向。
  • period:檢查點之間的間隔(時期數)。

  

# 模型編譯
model.compile(optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
              loss="sparse_categorical_crossentropy",
              metrics=['acc'])

# 建立一個儲存模型的回撥函式,每5個週期儲存一次權重
cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath='./checkpoint/train.ckpt',
    verbose=1,
    save_weights_only=True,
    period=5
)

# 模型訓練
model.fit(train_images, train_labels, batch_size=50, epochs=5, validation_split=0.1, callbacks=[cp_callback])

# 載入模型
model.load_weights('./checkpoint/train.ckpt')

# # 繼續訓練模型
# model.fit()