1. 程式人生 > >keras儲存模型和載入模型

keras儲存模型和載入模型

1、儲存模型和載入模型的方法

用實驗室的伺服器跑神經網路的時候伺服器老是斷開連線,這對我的訓練和測試來時是一件比較崩潰的事,因為這意味著我要重新訓練一次,要浪費又一次的時間,所以我在網上百度了儲存模型和載入模型的辦法,大部分的方法如下:

儲存模型

model.save('my_model.h5')      

載入模型

model = load_model('my_model.h5')

但是由於我的WbW,b和其他的一些優化引數是自定義的,所以在載入模型的時候就出現了一系列的問題,例如WW沒有初始化之類的問題,所以這時候,儲存模型就不如儲存權重來的方便,所以我就採用了儲存權重這種方法。

儲存模型的權重

model.save_weights('my_model.h5')

載入模型的權重

model.load_weights('my_model.h5')

2、儲存模型的方式

眾做周知,我們是在訓練模型的時候儲存模型的引數和超引數,在此基礎上進行引數的調整和調優,但是我們最好不要在每一個epoch裡都進行模型的儲存,因為儲存模型是一個耗費時間的操作,每次儲存模型的時候是要佔用CPU的時間的,這時候就浪費了GPU的資源,並且將每一次的引數都儲存下來也是不必要行為,我們可以採用以下幾種方法進行儲存:

  • 1、每隔50個epoch儲存
  • 2、當loss值比上一次小的時候儲存
  • 3、當正確率比上一次高的時候儲存