多項式迴歸案例(附資料集下載地址)
阿新 • • 發佈:2019-02-12
當我們完成了資料的預處理環節後,我們可以先對資料進行視覺化,根據影象可以初步的判斷我們的模型應該是怎麼樣的,如何更好地擬合,請看下面的例子:
資料集:
Position | Level | Salary |
---|---|---|
Business Analyst | 1 | 45000 |
Junior Consultant | 2 | 50000 |
Senior Consultant | 3 | 60000 |
Manager | 4 | 80000 |
Country Manager | 5 | 110000 |
Region Manager | 6 | 150000 |
Partner | 7 | 200000 |
Senior Partner | 8 | 300000 |
C-level | 9 | 500000 |
CEO | 10 | 1000000 |
#首先還是匯入必要的庫
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.preprocessing import PolynomialFeatures
import statsmodels.formula.api as sm
#載入資料,由於資料樣本少,不做分片
dataset = pd.read_csv('Position_Salaries.csv')
# X = dataset.iloc[:,1].values #shape -- (10,)是個向量,可是我們需要進入訓練的變數需要是一個矩陣
X = dataset.iloc[:,1:2].values #shape -- (10, 1) 矩陣
Y = dataset.iloc[:,2].values #shape -- (10,)是個向量
#視覺化資料
plt.figure(figsize=(10,14))
plt.subplot(211)
plt.scatter(X,Y)
plt.savefig('scatter.png')
資料描點完成了,根據影象容易得出一元方程擬合效果不會太好。
實際也正如我們料想的一致
觀察圖一,我們可以選用多項式來解決這個迴歸問題
- 對自變數進行矩陣轉化,轉化為有不同次數的矩陣
# 對X進行多次項處理
Poly = PolynomialFeatures(degree=2) #引數degree是限定生成的X矩陣的最高次數
X_poly = Poly.fit_transform(X)
輸出X_poly結果如下:
#這個操作自動添加了常數項的係數(第一列)第二列是一次項,第二列是二次項
[[ 1. 1. 1.]
[ 1. 2. 4.]
[ 1. 3. 9.]
[ 1. 4. 16.]
[ 1. 5. 25.]
[ 1. 6. 36.]
[ 1. 7. 49.]
[ 1. 8. 64.]
[ 1. 9. 81.]
[ 1. 10. 100.]]
poly_reg = sm.OLS(endog=Y,exog=X_poly).fit()
Y_pre2 = poly_reg.predict(X_poly)
plt.plot(X,poly_reg.predict(X_poly) ,color = 'black',label = 'poly_degree=2')
plt.legend()
plt.savefig('lin&poly.png')
顯然擬合也不是特別好。可以通過提高多項式次數來達到更好的擬合度,小心過度擬合問題
Poly = PolynomialFeatures(degree=3) #引數degree是限定生成的X矩陣的最高次數
X_poly = Poly.fit_transform(X)
poly_reg = sm.OLS(endog=Y,exog=X_poly).fit()
Y_pre2 = poly_reg.predict(X_poly)
plt.plot(X,poly_reg.predict(X_poly) ,color = 'green',label = 'poly_degree=3')
plt.legend() #提高到3次基本擬合效果很好了
plt.savefig('lin&poly3.png')
提高到3次基本擬合效果很好了
最後對影象進行優化處理,上述影象由於自變數的間距相對較大,影象不夠平滑。我們可以有如下操作:
X_grid = np.arange(min(X),max(X),0.1)
X_grid = X_grid.reshape(len(X_grid),1)
plt.plot(X_grid, poly_reg.predict(Poly.fit_transform(X_grid)) ,color = 'green',label = 'poly_degree=3')
plt.legend()
迴歸器資訊: