1. 程式人生 > >基於sciket-learn實現多項式迴歸

基於sciket-learn實現多項式迴歸

多項式迴歸在思想上和線性迴歸是一致的,都使用一條線去擬合樣本值,進入用得出的模型去進行預測,在樣本特徵呈現出線性特性時,我們可以用線性迴歸去做預測,但是在樣本特徵很複雜的時候,線性迴歸往往會呈現出欠擬合的狀態,這時就需要多項式迴歸。

先來看一個小例子,給定一條二次曲線y=2x^2 + 2x,生成帶噪聲的100個樣本點,繪製出影象 ,是我們熟悉的二次方程。

x = np.random.uniform(-3, 3, size = 100)
X = x.reshape(-1, 1)

y = 2 * x**2 + x + 2 + np.random.normal(0, 1, size = 100)

plt.scatter(x, y)
plt.show()

然後我們用線性迴歸來擬合這條曲線,看看會出現什麼情況

from sklearn.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X, y)

y_predict = lin_reg.predict(X)

plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show()

很明顯我們的預測函式沒有很好的擬合這些樣本點,當遇到這種情況時,我們不妨在增加一個特徵

X2 = np.hstack([X, X**2])

然後同樣呼叫sciket-learn為我們封裝好線性迴歸構造器,接著繪製出影象

lin_reg2 = LinearRegression()
lin_reg2.fit(X2, y)
y_predict2 = lin_reg2.predict(X2)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
plt.show()

這時,便可以看到,擬合程度已經比較好了。

sciket-learn中為我們提供了PolynomialFeatures來確定特徵的維度。

from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2)
poly.fit(X)
X3 = poly.transform(X)

lin_reg3 = LinearRegression()
lin_reg3.fit(X3, y)
y_predict3 = lin_reg3.predict(X3)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict3[np.argsort(x)], color='r')
plt.show()

可以看出得到的影象和上面的影象是一致的,這裡有興趣的朋友可以改變degree引數的值,看看會發生什麼樣的變化。

完整程式碼

import numpy as np
import matplotlib.pyplot as plt

x = np.random.uniform(-3, 3, size = 100)
X = x.reshape(-1, 1)
y = 2 * x**2 + x + 2 + np.random.normal(0, 1, size = 100)

plt.scatter(x, y)
plt.show()

from sklearn.linear_model import LinearRegression

lin_reg = LinearRegression()
lin_reg.fit(X, y)

y_predict = lin_reg.predict(X)

plt.scatter(x, y)
plt.plot(x, y_predict, color='r')
plt.show()

# 解決方案 新增一個特徵
(X**2).shape
X2 = np.hstack([X, X**2])
X2.shape

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict2[np.argsort(x)], color='r')
plt.show()

lin_reg2.coef_
lin_reg2.intercept_

from sklearn.preprocessing import PolynomialFeatures

poly = PolynomialFeatures(degree=2)
poly.fit(X)
X3 = poly.transform(X)

lin_reg3 = LinearRegression()
lin_reg3.fit(X3, y)
y_predict3 = lin_reg3.predict(X3)

plt.scatter(x, y)
plt.plot(np.sort(x), y_predict3[np.argsort(x)], color='r')
plt.show()