1. 程式人生 > >如何儲存訓練好的機器學習模型

如何儲存訓練好的機器學習模型

儲存訓練好的機器學習模型
        當我們訓練好一個model後,下次如果還想用這個model,我們就需要把這個model儲存下來,下次直接匯入就好了,不然每次都跑一遍,訓練時間短還好,要是一次跑好幾天的那怕是要天荒地老了。。sklearn官網提供了兩種儲存model的方法:官網地址

1.使用python自帶的pickle

from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
import pickle

#方法一:python自帶的pickle
(X,y) = datasets.load_iris(return_X_y=True)
rfc = RandomForestClassifier(n_estimators=100,max_depth=100)
rfc.fit(X,y)
print(rfc.predict(X[0:1,:]))
#save model
f = open('saved_model/rfc.pickle','wb')
pickle.dump(rfc,f)
f.close()
#load model
f = open('saved_model/rfc.pickle','rb')
rfc1 = pickle.load(f)
f.close()
print(rfc1.predict(X[0:1,:]))

2.使用sklearn中的模組joblib
使用joblib模組更加的簡單了,核心程式碼就兩行

from sklearn.ensemble import RandomForestClassifier
from sklearn import datasets
from sklearn.externals import joblib
#方法二:使用sklearn中的模組joblib
(X,y) = datasets.load_iris(return_X_y=True)
rfc = RandomForestClassifier(n_estimators=100,max_depth=100)
rfc.fit(X,y)
print(rfc.predict(X[0:1,:]))
#save model
joblib.dump(rfc, 'saved_model/rfc.pkl')
#load model
rfc2 = joblib.load('saved_model/rfc.pkl')
print(rfc2.predict(X[0:1,:]))

這兩個方法都可以,但是更推薦用第二種,即joblib,因為根據官網介紹,速度更快。
---------------------
原文:https://blog.csdn.net/u012328159/article/details/79255805