機器學習——線性迴歸
阿新 • • 發佈:2018-12-13
1 def test_fj(): 2 X = np.array([[500, 3, 0.3], [1000, 1, 0.6], [750, 2, 0.3], [600, 5, 0.2], [1200, 1, 0.6]], dtype=float) 3 Y = np.array([10000, 9000, 8000, 12000, 8500], dtype=float) 4 5 x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.25) 6 print(x_train, x_test) 7 print('===================================================') 8 print(y_train, y_test) 9 10 std_x = StandardScaler() 11 x_train = std_x.fit_transform(x_train) 12 x_test = std_x.transform(x_test) 13 14 std_y = StandardScaler() 15 y_train = std_y.fit_transform(y_train.reshape(-1, 1)) 16 y_test = std_y.transform(y_test.reshape(-1, 1))17 18 lr = LinearRegression() 19 lr.fit(x_train, y_train) 20 print(lr.coef_) 21 22 # orign = std_y.inverse_transform(y_test[1]) 23 # print('orign is value:::::',orign) 24 # y_lr_predict = std_y.inverse_transform(lr.predict(np.array([x_test[1]]))) 25 y_lr_predict = std_y.inverse_transform(lr.predict(x_test))26 27 print('房價:', y_lr_predict) 28 print('評分:', r2_score(std_y.inverse_transform(y_test), y_lr_predict)) 29 30 31 def price_predict(): 32 # 資料有三個特徵:距離地鐵距離、附近小學、小區綠化率 33 X = np.array([[500, 3, 0.3], [1000, 1, 0.6], [750, 2, 0.3], [600, 5, 0.2], [1200, 1, 0.6]], dtype=float) 34 # 具有三個特徵的房屋對應的房價 35 Y = np.array([10000, 9000, 8000, 12000, 8500], dtype=float) 36 37 std_x = StandardScaler() 38 x_train = std_x.fit_transform(X) 39 40 std_y = StandardScaler() 41 y_train = std_y.fit_transform(Y.reshape(-1, 1)) 42 # 構建線性預測模型 43 lr = LinearRegression() 44 # 模型在歷史資料上進行訓練,Y.reshape(-1,1)將Y變為二維陣列,fit函式引數要求是二維陣列 45 lr.fit(x_train, y_train.reshape(-1, 1)) 46 # 使用訓練模型預測新房屋價格 47 distance = input('請輸入新房屋距離地鐵的距離:') 48 school = input('請輸入附近小學數量:') 49 green = input('請輸入小區綠化率:') 50 x_predict = std_x.transform(np.array([[distance, school, green]], dtype=float)) 51 print(std_y.inverse_transform(lr.predict(x_predict))) 52 # print(lr.predict(np.array([[distance, school, green]], dtype=float))) 53 # print(lr.predict(np.array([[1300, 3, 0.4]]))) 54 55 56 if __name__ == '__main__': 57 pairplot_analyse() 58 # heatmap_analyse() 59 # bostn_linear() 60 # log_fit() 61 # test_fj() 62 # price_predict() 63 pass