1. 程式人生 > >003-keras模型的保存於載入

003-keras模型的保存於載入

方式一: 儲存所有狀態

儲存模型和模型圖

# 儲存模型 model.save(file_path)
model_name = '{}/{}_{}_{}_v2.h5'.format(params['model_dir'],params['filters'],params['pool_size_1'],params['pool_size_2'])
model.save(model_name)

# 儲存模型圖
from keras.utils import plot_model
# 需要安裝pip install pydot
model_plot = '{}/{}_{}_{}_v2.png'.format(params['model_dir'],params['filters'],params['pool_size_1'],params['pool_size_2'])
plot_model(model, to_file=model_plot)

 

  • 模型圖如圖所示

儲存的模型圖

載入模型

from keras.models import load_model

model_path = '../docs/keras/100_2_3_v2.h5'
model = load_model(model_path)

  

優勢和弊端
優勢一在於模型儲存和載入就一行程式碼,寫起來很方便。
優勢二在於不僅儲存了模型的結構和引數,也儲存了訓練配置等資訊。以便於從上次訓練中斷的地方繼續訓練優化。
劣勢就是佔空間太大,我的模型用這種方式佔了一個G。【紅色部分就是上述模型採用第一種方式儲存的檔案】本地使用還好,如果是多人的模組需要整合,上傳或者同步將會很耗時。

 

 

方式二: 只儲存模型結構和模型引數

儲存模型

儲存模型圖部分和方式一相同。

import yaml
import json

# 儲存模型結構到yaml檔案或者json檔案
yaml_string = model.to_yaml()
open('../docs/keras/model_architecture.yaml', 'w').write(yaml_string)
# json_string = model.to_json()
# open('../docs/keras/model_architecture.json', 'w').write(json_string)

# 儲存模型引數到h5檔案
model.save_weights('../docs/keras/model_weights.h5')

  

載入模型

import yaml
import json
from keras.models import model_from_json
from keras.models import model_from_yaml

# 載入模型結構
model = model_from_yaml(open('../docs/keras/model_architecture.yaml').read())
# model = model_from_json(open('../docs/keras/model_architecture.json').read())

# 載入模型引數
model.load_weights('../docs/keras/model_weights.h5')

  

優勢和弊端

  • 優勢就是節省了硬碟空間,方便同步和協作
  • 劣勢是丟失了訓練的一些配置資訊