1. 程式人生 > 程式設計 >終於搞懂了Keras中multiloss的對應關係介紹

終於搞懂了Keras中multiloss的對應關係介紹

我就廢話不多說了,大家還是直接看程式碼吧~

model = Model(inputs=[src,tgt],outputs=[y,flow])  
#定義網路的時候會給出輸入和輸出
model.compile(optimizer=Adam(lr=lr),loss=[
           losses.cc3D(),losses.gradientLoss('l2')],loss_weights=[1.0,reg_param]) 
#訓練網路的時候指定loss,如果是多loss,
loss weights分別對應前面的每個loss的權重,最後輸出loss的和
train_loss = model.train_on_batch(
      [X,atlas_vol],[atlas_vol,zero_flow]) 
 #開始訓練,loss中y_pred 和y_true的對應關係是:
 #輸出y與atlas_vol算cc3Dloss,輸出flow與zero_flow算gradientloss

補充知識:keras伺服器用fit_generator跑的程式碼,loss,acc曲線圖的儲存

我就廢話不多說了,大家還是直接看程式碼吧~

import matplotlib.pyplot as plt

...  //資料處理程式碼 省略

history = model.fit_generator(
  image_generator,steps_per_epoch=2000 // 32,epochs=16,verbose=1,validation_data=image_generator_TEST,validation_steps=20
)

print(history.history.keys())
plt.switch_backend('agg')  #伺服器上面儲存圖片 需要設定這個
//acc
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train','test'],loc='upper left')
plt.savefig('acc.jpg')
//loss
plt.plot(history.history['loss'])
plt.plot(history.history['val_loss'])
plt.ylabel('loss')
plt.xlabel('epoch')
plt.legend(['train',loc='upper left')
plt.savefig('loss.jpg')

以上這篇終於搞懂了Keras中multiloss的對應關係介紹就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。