[ Keras ] ——基本使用:(4) 模型記錄點實時儲存方法
阿新 • • 發佈:2019-01-01
# Train one epoch at a time so that a full model snapshot can be written to
# disk after every epoch (100 epochs in total, numbered 1..100).
for epo in range(1, 101):
    print('epochs:', epo)

    # epochs=1: each fit() call advances training by exactly one epoch.
    model1.fit(
        X_train,
        y_train,
        batch_size=batch_size,
        epochs=1,
        validation_data=(X_test, y_test),
        shuffle=True,
    )
    print('\n第', epo, '個epoch完成\n')

    # One checkpoint file per epoch, e.g. CIFAR10_model_epoch_7.h5.
    checkpoint_name = 'CIFAR10_model_' + 'epoch_' + str(epo) + '.h5'
    model1.save(checkpoint_name)
或者
# Number of epochs trained between two consecutive weight checkpoints.
per_epochs = 10

# Each pass trains `per_epochs` epochs, then stores the history log and weights.
# Example: i_epo = 3 loads the weights saved at epoch (3-1)*per_epochs = 20.
for i_epo in range(4, 20):
    epochs_done_before = (i_epo - 1) * per_epochs
    epochs_done_after = per_epochs * i_epo
    weight_prefix = (project_path + 'Record_point/baseline_res_point/'
                     + 'Baseline_model_weight_epoch_')

    print('第', i_epo * per_epochs, '代:\n')

    # Resume from the previous checkpoint. The very first round (i_epo == 1)
    # would have nothing to load; with range(4, 20) this guard is always true.
    if i_epo != 1:
        model.load_weights(weight_prefix + str(epochs_done_before) + '.h5')

    # Main model training uses this call (real-time generator input).
    hist = model.fit_generator(
        gen,
        steps_per_epoch=808,
        epochs=per_epochs,
        validation_data=(x_val, y_val),
    )

    # Persist the training history of this round as plain text.
    log_path = (project_path
                + 'Record_point/baseline_res_point/result_save/log_baseline_'
                + str(epochs_done_after) + '.txt')
    with open(log_path, 'w') as f:
        f.write(str(hist.history))

    # Store the weights reached at the end of this round.
    model.save_weights(weight_prefix + str(epochs_done_after) + '.h5')