keras 網路結構和權值檔案的儲存與載入
阿新 • • 發佈:2019-02-07
以lenet網路為例
from keras.models import Model from keras.layers import Input,Conv2D,Dense,Flatten,MaxPooling2D,Activation ####################### net structure ########################## inp=Input(shape=(28,28,1)) out=Conv2D(20,(5,5))(inp) out=MaxPooling2D(pool_size(2,2))(out) out=Conv2D(50,(5,5))(inp) out=MaxPooling2D(pool_size(2,2))(out) out=Flatten()(out) out=Dense(500)(out) out=Activation('relu')(out) out=Dense(2)(out) out=Activation('softmax')(out) ################################################################# model=Model(inp,out) model.compile( optimizer='sgd', loss='mse', metrics=["accuracy"] )
##一、網路結構
儲存網路結構和權值(save)
利用Model類的save函式,如
model.save('./model_name.h5')
載入模型(load_model)
from keras.models import load_model
model=load_model('./model_name.h5')
##二、權值檔案
儲存權值檔案(save_weights)
model.fit(input,output,epochs=100,batch_size=64)
#權值的儲存在訓練完成之後
model.save_weights('./weights.h5')
載入權值檔案(load_weights)
model=load_model('./model_name.h5')
model.load_weights('./weights.h5')
#載入後就可以預測了
result=model.predict(input,batch_size=1)