1. 程式人生 > [Keras] —— 基本使用 (4):模型記錄點實時儲存方法

[Keras] —— 基本使用 (4):模型記錄點(checkpoint)實時儲存方法

# Train for 100 epochs, calling fit() once per epoch so that a full model
# snapshot (saved with model1.save to an .h5 file) can be written after
# every single epoch. Any epoch's snapshot can later be reloaded.
for epo in range(1, 100 + 1):
    print('epochs:', epo)
    model1.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        epochs=1,  # exactly one epoch per call; the loop drives the epoch count
        validation_data=(X_test, y_test),
        shuffle=True,
    )
    print('\n第', epo, '個epoch完成\n')
    # One file per epoch, e.g. CIFAR10_model_epoch_7.h5
    checkpoint_file = 'CIFAR10_model_' + 'epoch_' + str(epo) + '.h5'
    model1.save(checkpoint_file)

或者(分段訓練,每段結束儲存權重,並可從記錄點恢復訓練):

# Resumable training in chunks of `per_epochs` epochs.
# Chunk i_epo covers global epochs ((i_epo - 1) * per_epochs, i_epo * per_epochs].
# Each chunk does the following:
#   1. Reloads the weights saved at the end of the previous chunk.
#   2. Trains per_epochs epochs.
#   3. Writes the Keras history dict to a text log.
#   4. Saves the weights under the chunk's end epoch.
# range(4, 20) resumes from the weights saved at epoch 30 and trains up to epoch 190.
per_epochs = 10
for i_epo in range(4, 20):  # e.g. i_epo = 3 loads the weights saved at epoch (3 - 1) * per_epochs = 20
    start_epoch = (i_epo - 1) * per_epochs  # epoch whose weights this chunk resumes from
    end_epoch = i_epo * per_epochs          # epoch reached (and saved) at the end of this chunk
    print('第', end_epoch, '代:\n')
    if i_epo != 1:  # chunk 1 trains from scratch; later chunks resume from a checkpoint
        model.load_weights(project_path + 'Record_point/baseline_res_point/' +
                           'Baseline_model_weight_epoch_' + str(start_epoch) + '.h5')
    # initial_epoch/epochs pass the true global epoch range to Keras. Its progress
    # display and any epoch-indexed callbacks then continue from start_epoch instead
    # of restarting at 0 on every chunk. Exactly per_epochs epochs are still trained.
    hist = model.fit_generator(gen, steps_per_epoch=808,
                               initial_epoch=start_epoch, epochs=end_epoch,
                               validation_data=(x_val, y_val))  # main model is trained with this (real-time generator)
    with open(project_path + 'Record_point/baseline_res_point/result_save/log_baseline_' + str(end_epoch) + '.txt', 'w') as f:
        f.write(str(hist.history))

    model.save_weights(project_path + 'Record_point/baseline_res_point/' +
                       'Baseline_model_weight_epoch_' + str(end_epoch) + '.h5')