1. 程式人生 > >如何選擇迴歸損失函式

如何選擇迴歸損失函式

無論在機器學習還是深度領域中,損失函式都是一個非常重要的知識點。損失函式(Loss Function)是用來估量模型的預測值 f(x) 與真實值 y 的不一致程度。我們的目標就是最小化損失函式,讓 f(x) 與 y 儘量接近。通常可以使用梯度下降演算法尋找函式最小值。

損失函式有許多不同的型別,沒有哪種損失函式適合所有的問題,需根據具體模型和問題進行選擇。一般來說,損失函式大致可以分成兩類:迴歸(Regression)和分類(Classification)。

迴歸模型中的三種損失函式包括:均方誤差(Mean Square Error)、平均絕對誤差(Mean Absolute Error,MAE)、Huber Loss。

1. 均方誤差(Mean Square Error,MSE)

均方誤差指的就是模型預測值 f(x) 與樣本真實值 y 之間距離平方的平均值。其公式如下所示:

其中,yi 和 f(xi) 分別表示第 i 個樣本的真實值和預測值,m 為樣本個數。

為了簡化討論,忽略下標 i,m = 1,以 y-f(x) 為橫座標,MSE 為縱座標,繪製其損失函式的圖形:

MSE 曲線的特點是光滑連續、可導,便於使用梯度下降演算法,是比較常用的一種損失函式。而且,MSE 隨著誤差的減小,梯度也在減小,這有利於函式的收斂,即使固定學習因子,函式也能較快取得最小值。

平方誤差有個特性,就是當 yi 與 f(xi) 的差值大於 1 時,會增大其誤差;當 yi 與 f(xi) 的差值小於 1 時,會減小其誤差。這是由平方的特性決定的。也就是說, MSE 會對誤差較大(>1)的情況給予更大的懲罰,對誤差較小(<1)的情況給予更小的懲罰。從訓練的角度來看,模型會更加偏向於懲罰較大的點,賦予其更大的權重。

如果樣本中存在離群點,MSE 會給離群點賦予更高的權重,但是卻是以犧牲其他正常資料點的預測效果為代價,這最終會降低模型的整體效能。我們來看一下使用 MSE 解決含有離群點的迴歸模型。

import numpy as np
import matplotlib.pyplot as plt
x = np.linspace(1, 20, 40)
y = x + [np.random.choice(4) for _ in range(40)]
y[-5:] -= 8
X = np.vstack((np.ones_like(x),x))    # 引入常數項 1
m = X.shape[1]
# 引數初始化
W = np.zeros((1,2))

# 迭代訓練 
num_iter = 20
lr = 0.01
J = []
for i in range(num_iter):
   y_pred = W.dot(X)
   loss = 1/(2*m) * np.sum((y-y_pred)**2)
   J.append(loss)
   W = W + lr * 1/m * (y-y_pred).dot(X.T)

# 作圖
y1 = W[0,0] + W[0,1]*1
y2 = W[0,0] + W[0,1]*20
plt.scatter(x, y)
plt.plot([1,20],[y1,y2])
plt.show()

擬合結果如下圖所示:

可見,使用 MSE 損失函式,受離群點的影響較大,雖然樣本中只有 5 個離群點,但是擬合的直線還是比較偏向於離群點。這往往是我們不希望看到的。

2. 平均絕對誤差(Mean Absolute Error,MAE)

平均絕對誤差指的就是模型預測值 f(x) 與樣本真實值 y 之間距離的平均值。其公式如下所示:

為了簡化討論,忽略下標 i,m = 1,以 y-f(x) 為橫座標,MAE 為縱座標,繪製其損失函式的圖形:

直觀上來看,MAE 的曲線呈 V 字型,連續但在 y-f(x)=0 處不可導,計算機求解導數比較困難。而且 MAE 大部分情況下梯度都是相等的,這意味著即使對於小的損失值,其梯度也是大的。這不利於函式的收斂和模型的學習。

值得一提的是,MAE 相比 MSE 有個優點就是 MAE 對離群點不那麼敏感,更有包容性。因為 MAE 計算的是誤差 y-f(x) 的絕對值,無論是 y-f(x)>1 還是 y-f(x)<1,沒有平方項的作用,懲罰力度都是一樣的,所佔權重一樣。針對 MSE 中的例子,我們來使用 MAE 進行求解,看下擬合直線有什麼不同。

X = np.vstack((np.ones_like(x),x))    # 引入常數項 1
m = X.shape[1]
# 引數初始化
W = np.zeros((1,2))

# 迭代訓練 
num_iter = 20
lr = 0.01
J = []
for i in range(num_iter):
   y_pred = W.dot(X)
   loss = 1/m * np.sum(np.abs(y-y_pred))
   J.append(loss)
   mask = (y-y_pred).copy()
   mask[y-y_pred > 0] = 1
   mask[mask <= 0] = -1
   W = W + lr * 1/m * mask.dot(X.T)

# 作圖
y1 = W[0,0] + W[0,1]*1
y2 = W[0,0] + W[0,1]*20
plt.scatter(x, y)
plt.plot([1,20],[y1,y2],'r--')
plt.xlabel('x')
plt.ylabel('y')
plt.title('MAE')
plt.show()

注意上述程式碼中對 MAE 計算梯度的部分。

擬合結果如下圖所示:

顯然,使用 MAE 損失函式,受離群點的影響較小,擬合直線能夠較好地表徵正常資料的分佈情況。這一點,MAE 要優於 MSE。二者的對比圖如下:

選擇 MSE 還是 MAE 呢?

實際應用中,我們應該選擇 MSE 還是 MAE 呢?從計算機求解梯度的複雜度來說,MSE 要優於 MAE,而且梯度也是動態變化的,能較快準確達到收斂。但是從離群點角度來看,如果離群點是實際資料或重要資料,而且是應該被檢測到的異常值,那麼我們應該使用MSE。另一方面,離群點僅僅代表資料損壞或者錯誤取樣,無須給予過多關注,那麼我們應該選擇MAE作為損失。

3. Huber Loss

既然 MSE 和 MAE 各有優點和缺點,那麼有沒有一種啟用函式能同時消除二者的缺點,集合二者的優點呢?答案是有的。Huber Loss 就具備這樣的優點,其公式如下:

Huber Loss 是對二者的綜合,包含了一個超引數 δ。δ 值的大小決定了 Huber Loss 對 MSE 和 MAE 的側重性,當 |y−f(x)| ≤ δ 時,變為 MSE;當 |y−f(x)| > δ 時,則變成類似於 MAE,因此 Huber Loss 同時具備了 MSE 和 MAE 的優點,減小了對離群點的敏感度問題,實現了處處可導的功能。

通常來說,超引數 δ 可以通過交叉驗證選取最佳值。下面,分別取 δ = 0.1、δ = 10,繪製相應的 Huber Loss,如下圖所示:

Huber Loss 在 |y−f(x)| > δ 時,梯度一直近似為 δ,能夠保證模型以一個較快的速度更新引數。當 |y−f(x)| ≤ δ 時,梯度逐漸減小,能夠保證模型更精確地得到全域性最優值。因此,Huber Loss 同時具備了前兩種損失函式的優點。

下面,我們用 Huber Loss 來解決同樣的例子。

X = np.vstack((np.ones_like(x),x))    # 引入常數項 1
m = X.shape[1]
# 引數初始化
W = np.zeros((1,2))

# 迭代訓練 
num_iter = 20
lr = 0.01
delta = 2
J = []
for i in range(num_iter):
   y_pred = W.dot(X)
   loss = 1/m * np.sum(np.abs(y-y_pred))
   J.append(loss)
   mask = (y-y_pred).copy()
   mask[y-y_pred > delta] = delta
   mask[mask < -delta] = -delta
   W = W + lr * 1/m * mask.dot(X.T)

# 作圖
y1 = W[0,0] + W[0,1]*1
y2 = W[0,0] + W[0,1]*20
plt.scatter(x, y)
plt.plot([1,20],[y1,y2],'r--')
plt.xlabel('x')
plt.ylabel('y')
plt.title('MAE')
plt.show()

注意上述程式碼中對 Huber Loss 計算梯度的部分。

擬合結果如下圖所示:

可見,使用 Huber Loss 作為啟用函式,對離群點仍然有很好的抗干擾性,這一點比 MSE 強。另外,我們把這三種損失函式對應的 Loss 隨著迭代次數變化的趨勢繪製出來:

MSE:

MAE:

Huber Loss:

對比發現,MSE 的 Loss 下降得最快,MAE 的 Loss 下降得最慢,Huber Loss 下降速度介於 MSE 和 MAE 之間。也就是說,Huber Loss 彌補了此例中 MAE 的 Loss 下降速度慢的問題,使得優化速度接近 MSE。

最後,我們把以上介紹的迴歸問題中的三種損失函式全部繪製在一張圖上。

好了,以上就是紅色石頭對迴歸問題 3 種常用的損失函式包括:MSE、MAE、Huber Loss 的簡單介紹和詳細對比。這些簡單的知識點你是否已經完全掌握了呢?