線性迴歸10-模型儲存和載入
阿新 • • 發佈:2021-09-16
1 sklearn模型的儲存和載入API
- from sklearn.externals import joblib
- 儲存:joblib.dump(estimator, 'test.pkl')
- 載入:estimator = joblib.load('test.pkl')
2 線性迴歸的模型儲存載入案例
from sklearn.datasets import load_boston from sklearn.model_selection import train_test_split from sklearn.preprocessing import StandardScaler from sklearn.linear_model import LinearRegression,SGDRegressor,RidgeCV,Ridge from sklearn.metrics import mean_squared_error #from sklearn.externals import joblib import joblib def dump_load() : """ 模型儲存和載入 :return: """ # 1.獲取資料 boston = load_boston() # 2.資料處理 # 2.1 分割資料 x_train, x_test, y_train, y_test = train_test_split(boston.data, boston.target,random_state=22,test_size=0.2) # 3.特徵工程-資料標準化 transfer = StandardScaler() x_train = transfer.fit_transform(x_train) x_test = transfer.fit_transform(x_test) # # 4.機器學習-線性迴歸(梯度下降) # # 4.1 模型訓練 # estimator = Ridge(alpha=1.0) # estimator.fit(x_train, y_train) # # # 4.2 模型儲存 # joblib.dump(estimator,"./data/test1.pkl") # # 4.3 模型載入 estimator=joblib.load("./data/test1.pkl") # 5.模型評估 y_predict = estimator.predict(x_test) print("預測值為:\n", y_predict) print("模型中的係數為:\n", estimator.coef_) print("模型中的偏置為:\n", estimator.intercept_) # 評價指標 均方誤差 error = mean_squared_error(y_test, y_predict) print("均方誤差:\n", error) return None #呼叫函式 if __name__=='__main__': dump_load()
5.4 結果
直接呼叫模型和原本模型中的結果是一樣的
注:保證結果一致,需要引數一樣,同時執行時需要運行當前的程式碼