1. 程式人生 > 其它 >PyTorch深度學習實踐第三講——梯度下降

PyTorch深度學習實踐第三講——梯度下降

技術標籤:深度學習深度學習python

梯度下降

任務:依然是用三個資料點,擬合一個線性模型,模型的引數的確定採用梯度下降法,本問題中,損失函式是MSE,梯度很容易求解grad = \frac{1}{n}\sum x(wx-y)),程式碼比較簡單,需要注意的是梯度的計算需要初始w以及設定學習率alpha,深度學習中對於陷入區域性最優點有一定的解決辦法,但是要注意鞍點的處理(梯度取0的點)。

def forward(x):
    return w*x
def cost(xs, ys):
    cost = 0
    for x_val, y_val in zip(xs, ys):
        y_pred = forward(x_val)
        cost += (y_val-y_pred) ** 2
    return cost / len(xs)
def gradient(xs, ys):
    grad = 0
    for x_val, y_val in zip(xs, ys):
        grad += 2 * (w * x_val - y_val) * x_val
    return grad / len(xs)
w = 1.0
alpha = 0.04
list_cost = []
list_epoch = []
for epoch in range(100):
    cost_val = cost(x_data, y_data)
    grad_val = gradient(x_data, y_data)
    w -= alpha * grad_val
    print("epoch: {0}, w = {1}, cost = {2}".format(epoch, w, cost_val))
    list_epoch.append(epoch)
    list_cost.append(cost_val)
print("Predict : x = 4, y = ", forward(4))
plt.xlabel("epoch")
plt.ylabel("cost")
plt.plot(list_epoch, list_cost)
plt.show()

梯度下降演算法將所有樣本作為一個整體來計算,由於樣本與樣本之間沒有依賴關係,所以可以使用並行的方式,這樣演算法的速度能夠很快但是效能難以保證;而隨機梯度下降法(SGD),每次僅拿出一個樣本來計算梯度從而更新w值,演算法效能可能更優,但是在更新w時,樣本之間是有前後依賴關係的(因為每次更新w都是通過一個樣本),演算法的速度較慢。為了折中,深度學習中採用mini-batch的方法(將若干個樣本分成一組來更新w)。

import numpy as np
import matplotlib.pyplot as plt
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6,0]
def forward(x):
    return w * x
def loss(x, y):
    y_pred = forward(x)
    return (y - y_pred) ** 2
def gradient(x, y):
    return (w * x - y) * x
w = 1.0
alpha = 0.04
list_cost = []
list_epoch = []
for epoch in range(100):
    grad = 0
    for x_val, y_val in zip(x_data, y_data):
        loss_val = loss(x_val, y_val)
        grad = gradient(x_val, y_val)
        w -= alpha * grad
        print("sample x = {0}, y = {1}, grad = {2}".format(x_val, y_val, grad))
    print("Epoch = {0}, w = {1}, loss = {2}".format(epoch, w, loss_val))
    list_epoch.append(epoch)
    list_cost.append(loss_val)
print("Predict : x = 4, y = ", forward(4))
plt.xlabel("epoch")
plt.ylabel("cost")
plt.plot(list_epoch, list_cost)
plt.show()