1. 程式人生 > 其它 >機器學習系列:LightGBM 視覺化調參

機器學習系列:LightGBM 視覺化調參

大家好,在100天搞定機器學習|Day63 徹底掌握 LightGBM一文中,我介紹了LightGBM 的模型原理和一個極簡例項。最近我發現Huggingface與Streamlit好像更配,所以就開發了一個簡易的 LightGBM 視覺化調參的小工具,旨在讓大家可以更深入地理解 LightGBM

網址:
https://huggingface.co/spaces/beihai/LightGBM-parameter-tuning

我只隨便放了幾個引數,調整這些引數可以實時看到模型評估指標的變化。程式碼我也放到文章中了,大家有好的優化思路可以留言。下面就詳細介紹一下實現過程:

LightGBM 的引數

在完成模型構建之後,必須對模型的效果進行評估,根據評估結果來繼續調整模型的引數、特徵或者演算法,以達到滿意的結果。

LightGBM,有核心引數,學習控制引數,IO引數,目標引數,度量引數,網路引數,GPU引數,模型引數,這裡我常修改的便是核心引數,學習控制引數,度量引數等。

Control Parameters 含義 用法
max_depth 樹的最大深度 當模型過擬合時,可以考慮首先降低 max_depth
min_data_in_leaf 葉子可能具有的最小記錄數 預設20,過擬合時用
feature_fraction 例如 為0.8時,意味著在每次迭代中隨機選擇80%的引數來建樹 boosting 為 random forest 時用
bagging_fraction 每次迭代時用的資料比例 用於加快訓練速度和減小過擬合
early_stopping_round 如果一次驗證資料的一個度量在最近的early_stopping_round 回合中沒有提高,模型將停止訓練 加速分析,減少過多迭代
lambda 指定正則化 0~1
min_gain_to_split 描述分裂的最小 gain 控制樹的有用的分裂
max_cat_group 在 group 邊界上找到分割點 當類別數量很多時,找分割點很容易過擬合時

CoreParameters 含義 用法
Task 資料的用途 選擇 train 或者 predict
application 模型的用途 選擇 regression: 迴歸時,binary: 二分類時,multiclass: 多分類時
boosting 要用的演算法 gbdt, rf: random forest, dart: Dropouts meet Multiple Additive Regression Trees, goss: Gradient-based One-Side Sampling
num_boost_round 迭代次數 通常 100+
learning_rate 如果一次驗證資料的一個度量在最近的 early_stopping_round 回合中沒有提高,模型將停止訓練 常用 0.1, 0.001, 0.003…
num_leaves 預設 31
device cpu 或者 gpu
metric mae: mean absolute error , mse: mean squared error , binary_logloss: loss for binary classification , multi_logloss: loss for multi classification

Faster Speed better accuracy over-fitting
將 max_bin 設定小一些 用較大的 max_bin max_bin 小一些
num_leaves 大一些 num_leaves 小一些
用 feature_fraction 來做 sub-sampling 用 feature_fraction
用 bagging_fraction 和 bagging_freq 設定 bagging_fraction 和 bagging_freq
training data 多一些 training data 多一些
用 save_binary 來加速資料載入 直接用 categorical feature 用 gmin_data_in_leaf 和 min_sum_hessian_in_leaf
用 parallel learning 用 dart 用 lambda_l1, lambda_l2 ,min_gain_to_split 做正則化
num_iterations 大一些,learning_rate 小一些 用 max_depth 控制樹的深度

模型評估指標

以分類模型為例,常見的模型評估指標有一下幾種:

混淆矩陣
混淆矩陣是能夠比較全面的反映模型的效能,從混淆矩陣能夠衍生出很多的指標來。

ROC曲線
ROC曲線,全稱The Receiver Operating Characteristic Curve,譯為受試者操作特性曲線。這是一條以不同閾值 下的假正率FPR為橫座標,不同閾值下的召回率Recall為縱座標的曲線。讓我們衡量模型在儘量捕捉少數類的時候,誤傷多數類的情況如何變化的。

AUC
AUC(Area Under the ROC Curve)指標是在二分類問題中,模型評估階段常被用作最重要的評估指標來衡量模型的穩定性。ROC曲線下的面積稱為AUC面積,AUC面積越大說明ROC曲線越靠近左上角,模型越優;

Streamlit 實現

Streamlit我就不再多做介紹了,老讀者應該都特別熟悉了。就再列一下之前開發的幾個小東西:

核心程式碼如下,完整程式碼我放到Github,歡迎大家給個Star

https://github.com/tjxj/visual-parameter-tuning-with-streamlit

from definitions import *

st.set_option('deprecation.showPyplotGlobalUse', False)
st.sidebar.subheader("請選擇模型引數:sunglasses:")

# 載入資料
breast_cancer = load_breast_cancer()
data = breast_cancer.data
target = breast_cancer.target

# 劃分訓練資料和測試資料
X_train, X_test, y_train, y_test = train_test_split(data, target, test_size=0.2)

# 轉換為Dataset資料格式
lgb_train = lgb.Dataset(X_train, y_train)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)

# 模型訓練
params = {'num_leaves': num_leaves, 'max_depth': max_depth,
            'min_data_in_leaf': min_data_in_leaf, 
            'feature_fraction': feature_fraction,
            'min_data_per_group': min_data_per_group, 
            'max_cat_threshold': max_cat_threshold,
            'learning_rate':learning_rate,'num_leaves':num_leaves,
            'max_bin':max_bin,'num_iterations':num_iterations
            }

gbm = lgb.train(params, lgb_train, num_boost_round=2000, valid_sets=lgb_eval, early_stopping_rounds=500)
lgb_eval = lgb.Dataset(X_test, y_test, reference=lgb_train)  
probs = gbm.predict(X_test, num_iteration=gbm.best_iteration)  # 輸出的是概率結果  

fpr, tpr, thresholds = roc_curve(y_test, probs)
st.write('------------------------------------')
st.write('Confusion Matrix:')
st.write(confusion_matrix(y_test, np.where(probs > 0.5, 1, 0)))

st.write('------------------------------------')
st.write('Classification Report:')
report = classification_report(y_test, np.where(probs > 0.5, 1, 0), output_dict=True)
report_matrix = pd.DataFrame(report).transpose()
st.dataframe(report_matrix)

st.write('------------------------------------')
st.write('ROC:')

plot_roc(fpr, tpr)

上傳Huggingface

Huggingface 前一篇文章(騰訊的這個演算法,我搬到了網上,隨便玩!)我已經介紹過了,這裡就順便再講一下步驟吧。

step1:註冊Huggingface賬號

step2:建立Space,SDK記得選擇Streamlit

step3:克隆新建的space程式碼,然後將改好的程式碼push上去

git lfs install 
git add .
git commit -m "commit from $beihai"
git push

push的時候會讓輸入使用者名稱(就是你的註冊郵箱)和密碼,解決git總輸入使用者名稱和密碼的問題:git config --global credential.helper store

push完成就大功告成了,回到你的space頁對應專案,就可以看到效果了。