1. 程式人生 > 實用技巧 >走進神經網路——02.梯度下降演算法介紹

走進神經網路——02.梯度下降演算法介紹

  梯度下降演算法在神經網路的訓練中扮演著非常重要的角色,在本文中,我們來仔細介紹一下梯度下降演算法。假設我們遇到了一個函式,其方程為F(x)=x2+10x+5,現在我們要去求它的最值(最大值或者最小值),如果讓我們自己來求,我們可以很容易看出該函式存在最小值,並且在對稱軸x=-b/2a=-5處取得最小值,最小值為-20。但是如果讓計算機去求它的最小值應該怎樣做呢?這就是我們今天要介紹的內容。

  我們畫出該方程的影象,如圖1-2所示:

圖1-2 函式F(x)=x2+10x+5的影象

從圖中可以看出,函式的最小值就是M點處,但是計算機應該怎樣找到M點,首先我們在該圖上隨便找一點,該點為A點,然後沿著A點,往下走一步,此時就到了B點,緊接著再往下走一點,就回到達C點,就這樣一步一步地走,總會到達我們要找的最小值點,也就是M點。可能這時候有人會問,在找到A點後,我們怎麼能夠判斷往哪邊走是向下走?如果往A的左邊走一直就會向上,這樣永遠就走不到M點了啊,此時我們需要考慮一下,怎樣才能保證每一步都是在向下走呢?這時候就輪到你在學校時天天抱怨學了沒用的導數起作用了,可能有人疑惑,導數在這裡有什麼用呢,那讓我們回憶一下導數的幾何意義,導數就是影象在該點處切線的斜率,也就是函式值上升的方向。等等,上升的方向?那麼給它添負號不就變成了函式值下降的方向,沒錯函式的導數,也稱為梯度,加負號就是函式值下降的方向。方向問題解決了,但還有一個問題,每次向下走一步,這一步到底是多大,在這裡我們會用一個常數a來表示這一步的大小,a被稱為學習率,是人為給定的,這個值不能太大也不能太小,太大了容易跨過最小值M,太小了向下走次數太多,不容易在短時間內找到M點。我們怎樣知道要找的M是否已經找到呢?當然是當該點的導數是否為0來決定,下面我們簡述一下梯度下降演算法:

1.隨便取一個值x0。

2.判斷函式在x0處的導數是否為0。

3.如果為0,函式的最小值點就是x0,如果不等於0,x0=x0-a*F1(x0)返回第2步。

接下來我們使用程式來實現:

public class Gradient {
    public static void main(String[] args) {
        float x0 = 1.0f;
        //學習率
        float a = 0.2f;
        int i=0;
        while (true) {
            //求函式F(x)=x2+10x+5在x0處的導數
            float
g=2*x0+10; System.out.println(""+i+"次的導數(梯度)為"+g); //一般不能精確到0,足夠小即可認為等於0 if (Math.abs(g)<0.001) { break; } x0=(x0-a*g); i++; } System.out.println("函式的最小值點為"+x0+",最小值為"+x0*x0+10*x0+5); } }

程式執行結果如下:

從圖中我們可以看到程式執行的結果很接近最小值點-5和最小值25,隨著迭代次數的增加,結果會越來越準確。