多元線性迴歸實驗學習筆記
阿新 • 發佈：2022-03-22
先貼上程式碼，文字說明有空再補寫。
"""Multiple linear regression fitted by batch gradient descent.

Model: z = theta0 + theta1 * x + theta2 * y.

When run as a script it reads three columns (x, y, z) from ``data2.csv``,
plots the raw and the standardized data, fits the plane to the standardized
data with batch gradient descent, then plots the fitted surface and the
objective value per iteration.
"""
import numpy as np

# Randomly initialized parameters (theta0 is the intercept). f() reads these
# module-level globals, and the training loop below rebinds them.
theata0 = np.random.rand()
theata1 = np.random.rand()
theata2 = np.random.rand()


# Prediction function f(x1, x2) = theta0 + theta1 * x1 + theta2 * x2
def f(x, y):
    """Return the model prediction for ``x`` and ``y`` using the current global thetas.

    Works element-wise on scalars or NumPy arrays (including meshgrids).
    """
    return theata0 + theata1 * x + theata2 * y


# Objective function
def E(x, y, z):
    """Return the objective: half the sum of squared residuals between ``z`` and f(x, y)."""
    return 0.5 * np.sum((z - f(x, y)) ** 2)


# Standardization function
def standardize(x):
    """Return ``x`` shifted to zero mean and scaled by its population standard deviation.

    ``x.std()`` uses NumPy's default ddof=0.
    """
    mu = x.mean()
    sigma = x.std()
    return (x - mu) / sigma


if __name__ == "__main__":
    # Plotting libraries are only needed when running as a script; importing
    # them here keeps the numeric helpers above importable without matplotlib.
    from matplotlib import projections
    import matplotlib.pyplot as plt
    from matplotlib import cm
    from mpl_toolkits.mplot3d import axes3d  # registers '3d' on older matplotlib

    # Load: column 0 -> x, column 1 -> y, column 2 -> z.
    # float dtype accepts both integer and decimal CSV values
    # (dtype='int' raised on decimals; integer data gives identical results).
    train = np.loadtxt('data2.csv', delimiter=',', dtype=float)
    train_x = train[:, 0]
    train_y = train[:, 1]
    train_z = train[:, 2]

    # Scatter plot of the raw data.
    ax = plt.subplot(111, projection='3d')
    ax.scatter(train_x, train_y, train_z, c="r")
    plt.show()

    train_x_std = standardize(train_x)
    train_y_std = standardize(train_y)
    train_z_std = standardize(train_z)

    # Scatter plot of the standardized data.
    ax = plt.subplot(111, projection='3d')
    ax.scatter(train_x_std, train_y_std, train_z_std, c="r")
    plt.show()

    ETA = 1e-3  # learning rate
    diff = 1  # decrease of the objective between consecutive iterations
    count = 0  # iteration counter
    cnt = []  # iteration indices (x-axis of the error curve)
    errs = []  # objective value after each iteration
    log = '第 {} 次 : theta0 = {:.3f}, theta1 = {:.3f},theta2 = {:.3f}, 差值 = {:.4f}'

    # Bug fix: the initial objective must use z as the target (was y). With the
    # wrong target the first diff could be negative and end training at once.
    error = E(train_x_std, train_y_std, train_z_std)

    # Stops once the objective decreases by no more than 1e-2. Note it also
    # stops immediately if the objective ever increases (diff < 0).
    while diff > 1e-2:
        # Compute the residual once. All three gradients then use the same
        # pre-update parameters, which gives a simultaneous update.
        residual = f(train_x_std, train_y_std) - train_z_std
        theata0, theata1, theata2 = (
            theata0 - ETA * np.sum(residual),
            theata1 - ETA * np.sum(residual * train_x_std),
            theata2 - ETA * np.sum(residual * train_y_std),
        )

        current_error = E(train_x_std, train_y_std, train_z_std)
        diff = error - current_error
        error = current_error

        cnt.append(count)
        errs.append(current_error)
        count += 1
        print(log.format(count, theata0, theata1, theata2, diff))

    # Standardized data together with the fitted plane.
    ax = plt.subplot(111, projection='3d')
    ax.scatter(train_x_std, train_y_std, train_z_std, c='r')
    x = np.arange(-3, 3, 0.1)
    y = np.arange(-3, 3, 0.1)
    x, y = np.meshgrid(x, y)
    z = f(x, y)
    surf = ax.plot_surface(x, y, z, cmap=cm.Blues, linewidth=1, antialiased=False)
    plt.show()

    # Objective value per iteration.
    plt.plot(cnt, errs)
    plt.show()