1. 程式人生 > >python xgboost踩坑實錄

python xgboost踩坑實錄

python xgboost踩坑實錄

前言

在python下執行xgboost有許多要注意的地方。
筆者在載入模型及載入資料的時候都踩了坑,為了避免再度踩坑,所以將解法記錄於此。

載入模型

如果儲存模型時是使用:

model.save_model('0001.model')

那麼在載入模型時不能單純地使用xgb.Booster

import xgboost as xgb
model = xgb.Booster('0001.model')

而是必須要先初始化模型(須傳入引數)後再載入.model

檔。

model = xgb.Booster({'nthread': 4})  # init model
model.load_model('0001.model')  # load data

載入資料

xgboost模型的資料必須為xgboost.DMatrix格式。
如果丟numpy.ndarray的格式進去會報錯。

參考Predicting unknown class from dumped Model

import xgboost as xgb
data = xgb.DMatrix(data) #convert from numpy array to xgboost.DMatrix
prediction = model.predict(data)

參考連結

Predicting unknown class from dumped Model
How to save & load xgboost model?
XGBoost - Python Package Introduction