1. 程式人生 > 其它 > 多元線性迴歸實驗學習筆記

多元線性迴歸實驗學習筆記

先貼個程式碼,有空再寫。

from matplotlib import projections
import numpy as np 
import matplotlib.pyplot as plt
from matplotlib import cm
from mpl_toolkits.mplot3d import axes3d

# Load the training samples: a comma-separated file parsed as integers
train = np.loadtxt('data2.csv', delimiter=',', dtype='int')
# Columns 0, 1 and 2 are used as x, y and z respectively
train_x, train_y, train_z = train[:, 0], train[:, 1], train[:, 2]

# 3D scatter plot of the raw (unscaled) samples
ax = plt.subplot(111, projection='3d')
ax.scatter(train_x, train_y, train_z, c="r")
plt.show()

# Random initial values for the three model parameters; each one is an
# independent draw from np.random.rand(), taken in the order theta0..theta2
theata0, theata1, theata2 = (np.random.rand() for _ in range(3))

# Prediction function: f(x1, x2) = a*x1 + b*x2 + c
def f(x, y):
    """Evaluate the linear model at ``(x, y)`` with the current parameters.

    The parameters are read from the module-level names ``theata0``
    (constant term), ``theata1`` (coefficient of ``x``) and ``theata2``
    (coefficient of ``y``) at call time, so the result follows whatever
    values those names hold when the function is called.
    """
    constant_term = theata0
    x_term = theata1 * x
    y_term = theata2 * y
    return constant_term + x_term + y_term

# Objective function
def E(x, y, z):
    """Return half the sum of squared differences between ``z`` and ``f(x, y)``."""
    residual = z - f(x, y)
    squared = residual ** 2
    return 0.5 * np.sum(squared)

# Standardization (z-score) function
def standardize(x):
    """Return ``x`` shifted to zero mean and scaled to unit standard deviation.

    Args:
        x: Array-like of numbers. It is converted with ``np.asarray`` so
            plain lists are accepted as well as arrays.

    Returns:
        An array with the shape of ``x`` holding ``(x - mean) / std``, where
        mean and std are taken over all elements. When the standard deviation
        is exactly zero (every element equal) the centered values, i.e. all
        zeros, are returned instead of dividing by zero.
    """
    x = np.asarray(x)
    mu = x.mean()
    sigma = x.std()
    if sigma == 0:
        # A constant input has no spread to normalize. Dividing would fill the
        # result with NaN, which then propagates through every later
        # computation that uses it.
        return x - mu
    return (x - mu) / sigma

# Standardize each of the three columns independently
train_x_std, train_y_std, train_z_std = (
    standardize(column) for column in (train_x, train_y, train_z)
)

# 3D scatter plot of the standardized samples
ax = plt.subplot(111, projection='3d')
ax.scatter(train_x_std, train_y_std, train_z_std, c="r")
plt.show()

ETA = 1e-3  # learning rate
diff = 1  # decrease of the objective between two consecutive iterations
count = 0  # number of completed iterations
cnt = []  # iteration indices (x-axis of the error curve)
errs = []  # objective value after each iteration (y-axis of the error curve)

# Objective value before the first update. The target must be the
# standardized z column, the same target used for `current_error` inside the
# loop; otherwise the first `diff` compares errors measured against two
# different targets and can end the loop after a single iteration.
error = E(train_x_std, train_y_std, train_z_std)
while diff > 1e-2:
    # Prediction error under the current parameters. It is evaluated once so
    # that all three parameters are updated from the same prediction.
    residual = f(train_x_std, train_y_std) - train_z_std

    # Store the updated values in temporaries first (simultaneous update)
    tmp_theata0 = theata0 - ETA * np.sum(residual)
    tmp_theata1 = theata1 - ETA * np.sum(residual * train_x_std)
    tmp_theata2 = theata2 - ETA * np.sum(residual * train_y_std)

    theata0 = tmp_theata0
    theata1 = tmp_theata1
    theata2 = tmp_theata2

    current_error = E(train_x_std, train_y_std, train_z_std)

    # The loop stops once an update lowers the objective by at most 1e-2
    diff = error - current_error
    error = current_error

    cnt.append(count)
    errs.append(current_error)
    count += 1

    log = '第 {} 次 : theta0 = {:.3f}, theta1 = {:.3f},theta2 = {:.3f}, 差值 = {:.4f}'
    print(log.format(count, theata0, theata1, theata2, diff))



# Standardized samples together with the fitted plane
ax = plt.subplot(111, projection='3d')
ax.scatter(train_x_std, train_y_std, train_z_std, c='r')

# Evaluate the model on a regular grid covering [-3, 3) in steps of 0.1
x, y = np.meshgrid(np.arange(-3, 3, 0.1), np.arange(-3, 3, 0.1))
z = f(x, y)

surf = ax.plot_surface(
    x, y, z,
    cmap=cm.Blues,
    linewidth=1,
    antialiased=False,
)

plt.show()

# Objective value against iteration index
plt.plot(cnt, errs)
plt.show()