Simple Linear Regression (OLS) - py
阿新 • Published: 2018-11-05
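A simple linear regression example, written mainly to learn the sklearn library and implement the code. It fits an ordinary least squares (OLS) model to a single feature of sklearn's diabetes dataset.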
http://scikit-learn.org/stable/index.html
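For reference, these are the standard OLS formulas behind the one-feature case (textbook definitions, not taken from the original post). With n training pairs (x_i, y_i), LinearRegression chooses the intercept and slope that minimize the sum of squared residuals; in sklearn, regr.intercept_ holds the fitted intercept and regr.coef_ holds the slope. The two evaluation metrics are computed on the m test points, with the R^2 baseline being the mean of the test targets.

```latex
\min_{\beta_0,\,\beta_1} \sum_{i=1}^{n} \left( y_i - \beta_0 - \beta_1 x_i \right)^2
\quad\Longrightarrow\quad
\hat{\beta}_1 = \frac{\sum_{i=1}^{n} (x_i - \bar{x})(y_i - \bar{y})}{\sum_{i=1}^{n} (x_i - \bar{x})^2},
\qquad
\hat{\beta}_0 = \bar{y} - \hat{\beta}_1 \bar{x}

\mathrm{MSE} = \frac{1}{m} \sum_{j=1}^{m} \left( y_j - \hat{y}_j \right)^2,
\qquad
R^2 = 1 - \frac{\sum_{j=1}^{m} \left( y_j - \hat{y}_j \right)^2}{\sum_{j=1}^{m} \left( y_j - \bar{y} \right)^2}
```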
# -*- coding: utf-8 -*-
"""
Created on Thu Nov 1 16:51:59 2018

@author: wp
"""
import matplotlib.pyplot as plt
import numpy as np
from sklearn import datasets, linear_model
from sklearn.metrics import mean_squared_error, r2_score

# Load the diabetes dataset (442 samples, 10 features)
diabetes = datasets.load_diabetes()

# Use only one feature (column 2, BMI), kept 2-D with shape (n_samples, 1)
diabetes_X = diabetes.data[:, np.newaxis, 2]

# Split the data: the last 20 samples form the test set
diabetes_X_train = diabetes_X[:-20]
diabetes_X_test = diabetes_X[-20:]
diabetes_y_train = diabetes.target[:-20]
diabetes_y_test = diabetes.target[-20:]

# Build the model and fit it by ordinary least squares
regr = linear_model.LinearRegression()
regr.fit(diabetes_X_train, diabetes_y_train)

# Predict on the test set
diabetes_y_pred = regr.predict(diabetes_X_test)

# Print the results
print('Coefficients: \n', regr.coef_)
# Mean squared error (the mean, not the sum, of squared residuals)
print("Mean squared error: %.2f" % mean_squared_error(diabetes_y_test, diabetes_y_pred))
# Coefficient of determination R^2 (1.0 means a perfect fit)
print('Coefficient of determination (R^2): %.2f' % r2_score(diabetes_y_test, diabetes_y_pred))

# Plot outputs
plt.scatter(diabetes_X_test, diabetes_y_test, color='black')
plt.plot(diabetes_X_test, diabetes_y_pred, color='blue', linewidth=3)
plt.xticks(())  # hide the axis tick labels
plt.yticks(())
plt.show()
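To see what LinearRegression, mean_squared_error and r2_score compute in this one-feature setting, here is a minimal from-scratch sketch using NumPy only. It assumes a single feature, as in the script above; ols_fit, predict, mse and r2 are hypothetical helper names (not sklearn API), and the toy data is made up so that the answer is known in advance.

```python
import numpy as np


def ols_fit(x, y):
    """Closed-form OLS for one feature (hypothetical helper): returns (slope, intercept)."""
    x = np.asarray(x, dtype=float).ravel()  # accept shape (n,) or (n, 1)
    y = np.asarray(y, dtype=float).ravel()
    x_mean, y_mean = x.mean(), y.mean()
    # slope = sum of cross-deviations / sum of squared x-deviations
    slope = np.sum((x - x_mean) * (y - y_mean)) / np.sum((x - x_mean) ** 2)
    # the fitted line always passes through the point (x_mean, y_mean)
    intercept = y_mean - slope * x_mean
    return slope, intercept


def predict(x, slope, intercept):
    """Evaluate the fitted line at each x (hypothetical helper)."""
    return slope * np.asarray(x, dtype=float).ravel() + intercept


def mse(y_true, y_pred):
    """Mean squared error: the average of the squared residuals."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    return np.mean((y_true - y_pred) ** 2)


def r2(y_true, y_pred):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    y_true = np.asarray(y_true, dtype=float).ravel()
    y_pred = np.asarray(y_pred, dtype=float).ravel()
    ss_res = np.sum((y_true - y_pred) ** 2)
    ss_tot = np.sum((y_true - y_true.mean()) ** 2)
    return 1.0 - ss_res / ss_tot


if __name__ == "__main__":
    # Made-up toy data lying exactly on y = 2x + 1, in the same (n, 1) shape as diabetes_X
    x = [[0.0], [1.0], [2.0], [3.0]]
    y = [1.0, 3.0, 5.0, 7.0]
    slope, intercept = ols_fit(x, y)
    y_hat = predict(x, slope, intercept)
    print(slope, intercept, mse(y, y_hat), r2(y, y_hat))
```

Calling ols_fit(diabetes_X_train, diabetes_y_train) should reproduce regr.coef_[0] and regr.intercept_ up to floating-point error, and applying mse and r2 to the test split should match mean_squared_error and r2_score. Note that MSE divides by the number of points, which is why it is a mean rather than a sum of squared errors.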