lightgbm,xgboost,gbdt的區別與聯絡學習筆記
阿新 • • 發佈:2018-12-18
1.lightGBM安裝
在anaconda中輸入:pip install lightGBM即可
輸入import lightgbm as lgb做測試
2.lightGBM改進
1.直方圖差加速:直方圖演算法的基本思想是先把連續的浮點特徵值離散化成k個整數,同時構造一個寬度為k的直方圖。在遍歷資料的時候,根據離散化後的值作為索引在直方圖中累積統計量,當遍歷一次資料後,直方圖累積了需要的統計量,然後根據直方圖的離散值,遍歷尋找最優的分割點。記憶體消耗降低,計算上的代價也大幅降低2.leaf-wise:每次從當前所有葉子中,找到分裂增益最大的一個葉子,然後分裂,如此迴圈。因此同Level-wise相比,在分裂次數相同的情況下,Leaf-wise可以降低更多的誤差,得到更好的精度。可能會長出比較深的決策樹,產生過擬合。因此LightGBM在Leaf-wise之上增加了一個最大深度限制,在保證高效率的同時防止過擬合。 3.特徵並行和資料並行:特徵並行的主要思想是在不同機器在不同的特徵集合上分別尋找最優的分割點,然後在機器間同步最優的分割點。資料並行則是讓不同的機器先在本地構造直方圖,然後進行全域性的合併,最後在合併的直方圖上面尋找最優分割點。4.直接支援類別特徵:可以直接輸入類別特徵,不需要額外的0/1 展開,LightGBM 是第一個直接支援類別特徵的 GBDT 工具。
3.lightGBM使用
#讀模型資料 import pandas as pd df=pd.read_csv('all_data.csv') #剔除模型無關變數和完全線性相關的變數 """取樣""" df_1 = df[df['LABEL'] == 1] df_0 = df[df['LABEL'] == 0] df3 = df_0.sample(10000) df_con= pd.concat([df_1, df3], ignore_index=True) #剔除模型無關變數和完全線性相關的變數 select_var = ~df.columns.isin(["PERSONID","LABEL","CREATETIME_time"]) X = df.ix[:, select_var] Y = df.LABEL.values #X=(X-X.mean(axis=0))/X.std(axis=0) X=X.fillna(0) from sklearn.cross_validation import train_test_split train_X, test_X, train_y, test_y = train_test_split(X, Y, train_size = 0.8, random_state = 123) import lightgbm as lgb lgb_train = lgb.Dataset(train_X, train_y, free_raw_data=False) lgb_eval = lgb.Dataset(test_X, test_y, reference=lgb_train,free_raw_data=False) param = { 'task': 'train', 'boosting_type': 'gbdt', 'objective': 'binary', 'metric': {'l2', 'auc'}, 'num_leaves': 40, 'learning_rate': 0.01, 'feature_fraction': 0.8, 'bagging_fraction': 0.8, 'bagging_freq': 5, 'verbose': 0 } param['is_unbalance']='true' param['metric'] = 'auc' bst=lgb.cv(param,lgb_train, num_boost_round=1000, nfold=6, early_stopping_rounds=100) estimators = lgb.train(param,lgb_train,num_boost_round=len(bst['auc-mean'])) print('Start training...') y_pred = estimators.predict(test_X, num_iteration=estimators.best_iteration) from sklearn import metrics print('The roc of prediction is:', metrics.roc_auc_score(test_y, y_pred) ) print('Feature names:', estimators.feature_name()) print('Feature importances:', list(estimators.feature_importance()))