機器學習演算法之自適應線性感知器演算法
阿新 • • 發佈:2019-02-12
自適應線性感知器演算法
原理推導
python實現
import numpy as np
import matplotlib.pyplot as plt
#隨機生成x和y, n=100, m=2
x = np.random.randn(100, 2)
y = x.dot(np.array([[2], [1]])) + 1
#初始話權值w和偏置b
w = np.zeros(x.shape[1]).reshape((x.shape[1], 1))#
b = 0
#設定超參
epoches = 100
alpha = 0.1
#誤差追蹤
error = []
#開始迭代
for e in range(epoches):
y_hat = np.dot(x, w) + b #計算y_hat
err = y_hat - y
error += [err.sum() ** 2 / (2 * len(x))]#計算損失函式
delta_w = x.T.dot(err) / (len(x)) #計算w偏導
delta_b = sum(err) / len(x) #計算b的偏導
#更新w和b
w -= delta_w * alpha
b -= delta_b * alpha
print(w)
print(b)
plt.plot(error)
plt.show()
執行結果
w和b,可以看出與真實值基本一致
誤差追蹤,誤差不斷下降
總結
在算偏導的時候一定不要忘了除以樣本數,在處理資料x時儘量把值域控制的小一點,從而控制delta_w和delta_b的值,以免過大影響學習。