1. 程式人生 > >二叔看ML第一:梯度下降

二叔看ML第一:梯度下降


原理

梯度下降是一個很常見的通過迭代求解函式極值的方法,當函式非常複雜,通過求導尋找極值很困難時可以通過梯度下降法求解。梯度下降法流程如下:
image.png-10.8kB
上圖中,用大寫字母表示向量,用小寫字母表示標量。

假設某人想入坑,他站在某點,他每移動一小步,都朝著他所在點的梯度的負方向移動,這樣能保證他儘快入坑,因為某個點的梯度方向是最陡峭的方向,如下圖所示,此圖畫的不太能表達這個觀點,但是懶得盜圖了,意會吧:
54

以下舉兩個例子,兩個例子中的被求函式都很簡單,其實直接求導算極值更好,此處僅用來說明梯度下降法的步驟。

實踐一:求\(y = x^2 - 4x + 1\)的最小值

# -*- coding: utf-8 -*-

import numpy as np
import matplotlib.pyplot as plt

def descent(p, original_x = 50, steplength = 0.01):
    ''' gradient descent, return min y '''

    deriv = p.deriv(m = 1) # 多項式p的導函式
    Y = [] # 儲存每次迭代後的y值,方便繪圖
    count = 0 # 迭代次數
    x = original_x # 設定x初始值
    d = deriv(x) # x位置的導數
    threshold = 0.001 # 閾值,當梯度小於此值時停止迭代

    while np.abs(d) > threshold:
        x = x - d * steplength
        y = p(x)

        Y.append(y)
        count += 1
        d = deriv(x)

    plt.plot(np.arange(1, count + 1), Y)
    plt.show()
    return y

if __name__ == "__main__":
    p = np.poly1d([2, -4, 1])
    min_y = descent(p)
    print(min_y)  

把迭代數和對應的函式值繪製出來以檢視迭代效果:
image.png-15.1kB

實踐二:求\(z = x^2 + y^2 + 5\)的最小值

以下程式碼中,把一組x和y當成一個向量處理,即\(z = X^TX + 5\),其中\(X=[x\ y]^T\)

import numpy as np
import matplotlib.pyplot as plt

def deriv(xy):
    dxy = 2 * xy
    return dxy

def descent(xy, steplength = 0.01):
    ''' gradient descent, return min y '''

    d = deriv(xy) # x^2 + y^2 + 5的梯度
    Y = [] # 儲存每次迭代後的y值,方便繪圖
    count = 0 # 迭代次數
    threshold = 0.001 # 閾值,當梯度的模小於此值時停止迭代

    while np.linalg.norm(d) > threshold:
        xy = xy - d * steplength
        y = np.dot(xy, xy) + 5

        Y.append(y)
        count += 1 
        d = deriv(xy)

    plt.plot(np.arange(1, count + 1), Y)
    plt.show()
    return Y[-1]

if __name__ == "__main__":
    y = descent(np.array([50, 50]))
    print(y)  

把迭代數和對應的函式值繪製出來以檢視迭代效果:
image.png-15.7kB

問答時間

Q:無法收斂到某個足夠小的函式值,最後報錯: overflow ...
A:步長設定太大,步子大了,容易跨過最低點,導致函式值在最低點上下震盪或發散,如圖:
收斂情況
可以人為設定迭代次數(而不是通過閾值控制是否繼續迭代),然後觀察函式值是否收斂:
步長的選擇

Q:如何選擇合適的步長
A:步長太大會導致函式值不收斂,步長太小又浪費效能,可以通過繪製如上面的迭代次數和函式值關係圖,剛才結果後調整步長,儘量選擇滿足需求的最大步長。達爺在他的網課中給出的建議是:按照這樣的序列試驗步長:..., 0.001, 0.003, 0.01, 0.03, 0.1, 0.3, 1, ...。通過演算法自動預測步長十分複雜,非二叔所能為。

Q:何時停止迭代?
A:可設定一個閾值,當梯度的模長小於這個閾值時停止迭代(當函式接近極值時,梯度接近0)。也可以人為通過剛才迭代次數和函式值影象設定迭代次數。

Q:是否還有其他迭代法?
A:還有牛頓法和擬牛頓法,和梯度下降法的區別是牛頓法不是沿著梯度負方向下降的,而是另一套演算法得出的方向,下降速度更快。

Q:迭代法是否一定會找到函式值域內的最小值?
A:不是,如果函式不是一個凸函式,那麼迭代法可能會找到一個區域性最小值或鞍點值。

Q:函式最大值怎麼找
A:給函式取個負號然後找最小值,或者沿著梯度方向前進而不是負梯度方向前進