1. 程式人生 > >keras 網路結構和權值檔案的儲存與載入

keras 網路結構和權值檔案的儲存與載入

以lenet網路為例

from keras.models import Model
from keras.layers import Input,Conv2D,Dense,Flatten,MaxPooling2D,Activation
####################### net structure ##########################
inp=Input(shape=(28,28,1))
out=Conv2D(20,(5,5))(inp)
out=MaxPooling2D(pool_size(2,2))(out)
out=Conv2D(50,(5,5))(inp)
out=MaxPooling2D(pool_size(2,2))(out)
out=Flatten()(out)
out=Dense(500)(out)
out=Activation('relu')(out)
out=Dense(2)(out)
out=Activation('softmax')(out)
#################################################################
model=Model(inp,out)
model.compile(
			optimizer='sgd',
			loss='mse',
			metrics=["accuracy"]
			)

##一、網路結構
儲存網路結構和權值(save)
利用Model類的save函式,如

model.save('./model_name.h5')

載入模型(load_model)

from keras.models import load_model
model=load_model('./model_name.h5')

##二、權值檔案
儲存權值檔案(save_weights)

model.fit(input,output,epochs=100,batch_size=64)
#權值的儲存在訓練完成之後
model.save_weights('./weights.h5')

載入權值檔案(load_weights)

model=load_model('./model_name.h5')
model.load_weights('./weights.h5')
#載入後就可以預測了
result=model.predict(input,batch_size=1)