終於搞懂了Keras中multiloss的對應關係介紹
阿新 • • 發佈:2020-06-22
我就廢話不多說了,大家還是直接看程式碼吧~
model = Model(inputs=[src,tgt],outputs=[y,flow]) #定義網路的時候會給出輸入和輸出 model.compile(optimizer=Adam(lr=lr),loss=[ losses.cc3D(),losses.gradientLoss('l2')],loss_weights=[1.0,reg_param]) #訓練網路的時候指定loss,如果是多loss, loss weights分別對應前面的每個loss的權重,最後輸出loss的和 train_loss = model.train_on_batch( [X,atlas_vol],[atlas_vol,zero_flow]) #開始訓練,loss中y_pred 和y_true的對應關係是: #輸出y與atlas_vol算cc3Dloss,輸出flow與zero_flow算gradientloss
補充知識:keras伺服器用fit_generator跑的程式碼,loss,acc曲線圖的儲存
我就廢話不多說了,大家還是直接看程式碼吧~
import matplotlib.pyplot as plt ... //資料處理程式碼 省略 history = model.fit_generator( image_generator,steps_per_epoch=2000 // 32,epochs=16,verbose=1,validation_data=image_generator_TEST,validation_steps=20 ) print(history.history.keys()) plt.switch_backend('agg') #伺服器上面儲存圖片 需要設定這個 //acc plt.plot(history.history['acc']) plt.plot(history.history['val_acc']) plt.title('model accuracy') plt.ylabel('accuracy') plt.xlabel('epoch') plt.legend(['train','test'],loc='upper left') plt.savefig('acc.jpg') //loss plt.plot(history.history['loss']) plt.plot(history.history['val_loss']) plt.ylabel('loss') plt.xlabel('epoch') plt.legend(['train',loc='upper left') plt.savefig('loss.jpg')
以上這篇終於搞懂了Keras中multiloss的對應關係介紹就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。