1. 程式人生 > >如何在Keras中儲存已經訓練好的模型

如何在Keras中儲存已經訓練好的模型

之前我們討論過《》的方法,彼時我們儲存的僅僅是NN中的模型引數,但是在Keras中我們可以非常優雅地把整個模型(包括已經訓練好的引數和神經網路的結構)儲存起來,而且這一切都“非常非常”簡單。

作為一個例子,這裡使用《》中給出的對MINST資料集進行手寫數字識別的程式碼,並將其中訓練好的模型進行儲存。

需要提前說明的一點是Keras會把模型儲存成“.h5”檔案,為了讓你的程式可以支援這種形式的檔案你需要安裝一下h5py這個package,對此你可以在命令列下輸入:sudo pip install h5py

在你確認完成了上面的步驟之後,接下來的內容就非常簡單, 只需要下面這樣的語句,已經建立好的模型就會被成功儲存了。

from keras.models import load_model

model.save('fei_model.h5')

如下圖所示,你會看到“fei_model.h5”檔案已經生成。完整的包括模型儲存部分的手寫數字識別程式碼的jupyter notebook檔案可以從文末給出的連結中下載到。


載入已經儲存好的模型同樣非常簡單。你可以另外開一個新檔案,匯入必要的包和資料(當然無需重新定義和訓練模型),然後使用下面這樣的語句:

from keras.models import load_model
my_model = load_model('/home/airobot/Desktop/fei_model.h5')
被載入的模型是否可以使用呢?我們用它來做一下predict,程式碼如下:
pred = my_model.predict(X_test_0[:])
print('Label of testing sample', np.argmax(y_test_0))
print('Output of the softmax layer', pred[0])
print('Network prediction:', np.argmax([pred[0]]))
如果你有讀之前的文章《》,可知下面的輸出表明,程式準確地預測了手寫數字為7。
('Label of testing sample', 7)
('Output of the softmax layer', array([  1.55717719e-23,   1.71202164e-13,   5.04616253e-14,
         7.39000111e-11,   3.02617696e-22,   3.59160995e-19,
         9.96468490e-34,   1.00000000e+00,   4.57078810e-19,
         2.00005355e-12], dtype=float32))
('Network prediction:', 7)

最後,本文中所使用的原始碼之jupyter notebook檔案均可從如下連結中獲取:

https://pan.baidu.com/s/1nvdxYC5

(本文完)