1. 程式人生 > >Keras深度學習框架之損失函式

Keras深度學習框架之損失函式

一.損失函式的使用

  損失函式【也稱目標函式或優化評分函式】是編譯模型時所需的兩個引數之一。

  model.compile(loss='mean_squared_error', optimizer='sgd')

  或

  from keras import losses

  model.compile(loss=losses.mean_squared_error, optimizer='sgd')

  可以傳遞一個現有的損失函式名或者一個TensorFlow/Theano符號函式。該符號函式為每個資料點返回一個標量,有一下兩個引數:

  1.y_true

    真實標籤,TensorFlow/Theano張量。

  2.y_pred

    預測值,TensorFlow/Theano張量,其shape與y_true相同。

  實際的優化目標是所有資料點的輸出陣列的平均值。

二.可用的損失函式

  1.mean_squared_error(y_true, y_pred)【MSE,均方誤差】

    計算公式:

    

    原始碼:

    

  2.mean_absolute_error(y_true, y_pred)【MAE,平均絕對誤差】

    提到MAE就不能不說顯著性目標檢測,所謂顯著性目標,舉個例子來說,當我們觀察一張圖片時,我們會首先關注那些顏色鮮明,奪人眼球的內容。就像我們看變形金剛時會首先看擎天柱一樣,這是絕對的C位。所以我們把變形金剛中的擎天柱定義為顯著性目標。

    在顯著性目標檢測中的評價指標計算中,常用的檢測演算法就有平均絕對誤差,其計算公式如下:

    

    原始碼:

    

  3.mean_absolute_percentage_error【MAPE,平均絕對百分比誤差】

    與平均絕對誤差類似,平均絕對百分比誤差預測結果與真實值之間的偏差比例。計算公式如下:

    

    原始碼:

    

    備註:

    1.clip

      逐元素,將超出指定範圍的數強制變為邊界數。

    2.epsilon

      固定引數,預設值為1*e-7。

  4.mean_squared_logarithmic_error【MSLE,均方對數誤差】

    在計算均方誤差之前先對資料取對數,再計算。

    計算公式:

    

     原始碼:

    

  5.squared_hinage【不常用】

    計算公式:

    

    原始碼:

    

  6.hinage【不常用】

    計算公式:

    

    原始碼:

    

  7.categorical_hinge【不常用】

    原始碼:

    

  8.logcosh【不常用】

    預測誤差的雙曲餘弦的對數。計算結果與均方誤差大致相同,但不會受到偶爾瘋狂的錯誤預測的強烈影響。

    原始碼:

    

  9.categorical_crossentropy【不常用】

    當使用categorical_crossentropy損失時,目標值應該是分類格式【即假如是10類,那麼每個樣本的目標值應該是一個10維的向量,這個向量除了表示類別的那個索引為1,其它均為0】。為了將整數目標值轉換為分類目標值,可以使用keras實用函式to_categorical。

    from keras.utils.np_utils import to_categorical

    categorical_labels = to_categorical(int_labels, num_classes=None)

    原始碼:

    

   10.sparse_categorical_crossentropy【不常用】

    原始碼:

    

   11.binary_crossentropy【不常用】

    原始碼:

    

  12.kullback_leibler_divergence【不常用】

    原始碼:

    

  13.poisson【不常用】

    計算公式:

    

     原始碼:

    

  14.cosine_proximity【不常用】

    計算公式:

    

    原始碼:

    

三.其它型別的損失函式

  1.ctc_batch_cost【高效能】

    原始碼:

    

 

    在每個批處理元素上執行CTC損失演算法。

    引數:

    1.y_true

      包含真實值標籤的張量。型別(samples, max_string_length).

    2.y_pred

      包含預測值或softmax輸出的張量。型別(samples, time_steps, num_categories)。

    3.input_length

      張量(samples, 1),包含y_pred中每個批處理項的序列長度。

    4.label_length

      張量(samples, 1), 包含y_true中每個批處理項的序列長度。

    返回shape為(samples, 1)的張量,包含每一個元素的CTC損