1. 程式人生 > 實用技巧 >梯度下降演算法之方程求解

梯度下降演算法之方程求解

從上個月專攻機器學習,從本篇開始,我會陸續寫機器學習的內容,都是我的學習筆記。

問題

梯度下降演算法用於求數學方程的極大值極小值問題,這篇文章講解如何利用梯度下降演算法求解方程 \(x^5+e^x+3x−3=0\) 的根;

方法

首先來解決第一個問題,從方程的形式我們就能初步判斷,它很可能沒有閉式解。我能想到的最直觀的解決方法就是畫出函式圖,函式圖與 x 軸的交點就是方程的解,那先畫個圖看看

從函式影象大體可以判斷,方程的根在 0 附近,但是很明顯 0 不是方程的根,看圖只能猜出個大概,那怎麼做才能得到更精確的解呢?

有一個可行的方法在 x = 0 附近找一堆很接近的數字,比如 [−0.5:0.05:1][−0.5:0.05:1],一個個代入方程的左邊,看看它的值離 0 有多近

:距離 0 越近,說明我們選取的值離方程的根也越近。數學上定義兩個數距離就是絕對值,但是因為絕對值不便於計算,所以將其替換成等價的差的平方,即 F(x)=(f(x)−0)2F(x)=(f(x)−0)2,以此度量結果距離 0 的程度,稱之為損失函式

我們代入計算得到如下的結果

x: -0.500, f(x): -3.9247, F(x): 15.4034
x: -0.450, f(x): -3.7308, F(x): 13.9191
x: -0.400, f(x): -3.5399, F(x): 12.5310
x: -0.350, f(x): -3.3506, F(x): 11.2263
x: -0.300, f(x): -3.1616, F(x): 9.9958
x: -0.250, f(x): -2.9722, F(x): 8.8338
x: -0.200, f(x): -2.7816, F(x): 7.7372
x: -0.150, f(x): -2.5894, F(x): 6.7048
x: -0.100, f(x): -2.3952, F(x): 5.7369
x: -0.050, f(x): -2.1988, F(x): 4.8346
x: -0.000, f(x): -2.0000, F(x): 4.0000
x: 0.050, f(x): -1.7987, F(x): 3.2354
x: 0.100, f(x): -1.5948, F(x): 2.5434
x: 0.150, f(x): -1.3881, F(x): 1.9268
x: 0.200, f(x): -1.1783, F(x): 1.3883
x: 0.250, f(x): -0.9650, F(x): 0.9312
x: 0.300, f(x): -0.7477, F(x): 0.5591
x: 0.350, f(x): -0.5257, F(x): 0.2763
x: 0.400, f(x): -0.2979, F(x): 0.0888
x: 0.450, f(x): -0.0632, F(x): 0.0040
x: 0.500, f(x): 0.1800, F(x): 0.0324
x: 0.550, f(x): 0.4336, F(x): 0.1880
x: 0.600, f(x): 0.6999, F(x): 0.4898
x: 0.650, f(x): 0.9816, F(x): 0.9635
x: 0.700, f(x): 1.2818, F(x): 1.6431
x: 0.750, f(x): 1.6043, F(x): 2.5738
x: 0.800, f(x): 1.9532, F(x): 3.8151
x: 0.850, f(x): 2.3334, F(x): 5.4445
x: 0.900, f(x): 2.7501, F(x): 7.5630
x: 0.950, f(x): 3.2095, F(x): 10.3008

可以看出,x = 0.5,結果已經很接近 0 了,方程的根應該在 0.45~0.50 之間,而且 0.45 時,F(x) 的值更小,說明離 0.45 距離更近。接下來,一個可行的方法是將這段再細分成更小的區間,再如上面這樣嘗試,直到結果滿意為止。但是這樣做太過機械,每次需要手動調整區間和步長,有沒有一種方法可以自動調整呢?

再回到我們的問題,求解方程的根,就是找到一個點使得損失函式最小,我們畫出來這個函式的曲線看看

我們假定方程的根是 x0x0,除了 x0x0,其他點的函式值都比該點處的高,而且從兩邊向內,越是靠近 x0x0,函式的值越接近 0。而且可以發現,從兩邊向 x0x0 移動,方向剛好就是該點處切線的斜率 F′(x)F′(x) 的相反數。

於是得到啟發,挑選一個初始點,沿著該點的斜率相反的方向迭代,必然越來越靠近方程的根,所以有下面的演算法:

  1. 對於方程 f(x)=0f(x)=0,舍設定損失函式 F(x)=(f(x)−0)2F(x)=(f(x)−0)2;
  2. 設定一個初值 x0x0,代入損失函式求得結果,如果大於 0,那麼找到一個新的值 x1=x0−αF′(x0)x1=x0−αF′(x0),考察損失函式是否為 0;
  3. 反覆迭代第 2 步,直到達到滿意的精度為止。

上面的演算法中,有三個引數需要注意:

  • αα,稱為學習率,代表了曲線逼近的速度,這個引數可以自己設定;
  • 迭代次數,第 2 步執行的次數,迭代次數越多,我們離理想的結果越接近;
  • 精度,定義為 |F(x)||F(x)|,表示迭代的效果

這三個引數中,迭代次數和精度可以作為迭代的終止條件,比如迭代次數達到 10000 次或者精度達到某個很小的數值 σσ 就終止執行。

下面我們使用 python 程式來演示該演算法的效果:

# _*_ coding: utf-8 _*_
import numpy as np

# 定義函式f(x)
    e = 2.71828182845904590
    return x**5 + e**x + 3*x - 3

#定義損失函式
def loss_fun(x):
    return (problem(x) - 0)**2

#計算損失函式的斜率
def slope_fx(x):
    delta  = 0.0000001;
    return (loss_fun(x+delta) - loss_fun(x-delta))/(2.0*delta)

#代入f(x),計算數值
def calcu_loss_fun(x,maxTimes,alpha):
        for i in range(maxTimes):
            x = x - slope_fx(x)*alpha;
            print 'times %d, x: %.13f, f(x): %.13f' % (i, x, problem(x))
alpha = 0.01
maxTimes = 100
x = 0.0;

calcu_loss_fun(x,maxTimes,alpha)

其中的slope_fx計算方程的斜率,利用導數定義 f′(x)=f(x+Δx)−f(x)Δxf′(x)=f(x+Δx)−f(x)Δx。程式計算結果如下

times 1, x: 0.2724712244717, f(x): -0.8678788871194
times 2, x: 0.3478163723702, f(x): -0.5354882897920
times 3, x: 0.3958941025006, f(x): -0.3168805921512
times 4, x: 0.4251012218626, f(x): -0.1810687680246
times 5, x: 0.4420964369242, f(x): -0.1008566369730
times 6, x: 0.4516717013511, f(x): -0.0552506486831
times 7, x: 0.4569525930429, f(x): -0.0299651603458
times 8, x: 0.4598276021739, f(x): -0.0161585445219
times 9, x: 0.4613811940466, f(x): -0.0086856358075
times 10, x: 0.4622172450759, f(x): -0.0046606160693
times 11, x: 0.4626661379649, f(x): -0.0024984737671
times 12, x: 0.4629068614830, f(x): -0.0013387061269
times 13, x: 0.4630358664583, f(x): -0.0007170954782
times 14, x: 0.4631049762781, f(x): -0.0003840652503
times 15, x: 0.4631419923255, f(x): -0.0002056832476
times 16, x: 0.4631618165349, f(x): -0.0001101474736
times 17, x: 0.4631724329502, f(x): -0.0000589848326
times 18, x: 0.4631781181683, f(x): -0.0000315864570
times 19, x: 0.4631811626230, f(x): -0.0000169144811
times 20, x: 0.4631827929259, f(x): -0.0000090576372
times 21, x: 0.4631836659475, f(x): -0.0000048503201
times 22, x: 0.4631841334466, f(x): -0.0000025973198
times 23, x: 0.4631843837899, f(x): -0.0000013908497
times 24, x: 0.4631845178473, f(x): -0.0000007447918
times 25, x: 0.4631845896343, f(x): -0.0000003988315
times 26, x: 0.4631846280757, f(x): -0.0000002135719
times 27, x: 0.4631846486609, f(x): -0.0000001143664
times 28, x: 0.4631846596842, f(x): -0.0000000612425
times 29, x: 0.4631846655870, f(x): -0.0000000327950
times 30, x: 0.4631846687480, f(x): -0.0000000175615
times 31, x: 0.4631846704407, f(x): -0.0000000094041
times 32, x: 0.4631846713471, f(x): -0.0000000050358
times 33, x: 0.4631846718325, f(x): -0.0000000026967
times 34, x: 0.4631846720924, f(x): -0.0000000014440
times 35, x: 0.4631846722316, f(x): -0.0000000007733
times 36, x: 0.4631846723061, f(x): -0.0000000004141
times 37, x: 0.4631846723460, f(x): -0.0000000002217
times 38, x: 0.4631846723674, f(x): -0.0000000001187
times 39, x: 0.4631846723788, f(x): -0.0000000000636
times 40, x: 0.4631846723850, f(x): -0.0000000000340
times 41, x: 0.4631846723882, f(x): -0.0000000000182
times 42, x: 0.4631846723900, f(x): -0.0000000000098
times 43, x: 0.4631846723909, f(x): -0.0000000000052
times 44, x: 0.4631846723914, f(x): -0.0000000000028
times 45, x: 0.4631846723917, f(x): -0.0000000000015
times 46, x: 0.4631846723919, f(x): -0.0000000000008
times 47, x: 0.4631846723919, f(x): -0.0000000000004
times 48, x: 0.4631846723920, f(x): -0.0000000000002
times 49, x: 0.4631846723920, f(x): -0.0000000000001
times 50, x: 0.4631846723920, f(x): -0.0000000000001
times 51, x: 0.4631846723920, f(x): -0.0000000000000
times 52, x: 0.4631846723920, f(x): -0.0000000000000
times 53, x: 0.4631846723920, f(x): -0.0000000000000
times 54, x: 0.4631846723920, f(x): -0.0000000000000

迭代 52 次,就已經達到了理想的效果。

參考資料