1. 程式人生 > >tensflow自定義損失函式

tensflow自定義損失函式

三、自定義損失函式

標準的損失函式並不合適所有場景,有些實際的背景需要採用自己構造的損失函式,Tensorflow 也提供了豐富的基礎函式供自行構建。
例如下面的例子:當預測值(y_pred)比真實值(y_true)大時,使用 (y_pred-y_true)*loss_more 作為 loss,反之,使用 (y_true-y_pred)*loss_less

loss = tf.reduce_sum(tf.where(tf.greater(y_pred, y_true), (y_pred-y_true)*loss_more,(y_true-y_pred)*loss_less))

tf.greater(x, y):判斷 x 是否大於 y,當維度不一致時廣播後比較
tf.where(condition, x, y):當 condition 為 true 時返回 x,否則返回 y 
tf.reduce_mean():沿維度求平均
tf.reduce_sum():沿維度相加
tf.reduce_prod():沿維度相乘
tf.reduce_min():沿維度找最小
tf.reduce_max():沿維度找最大
使用 Tensorflow 提供的方法可自行構造想要的損失函式。