1. 程式人生 > 其它 >sklearn.model_selection.GridSearchCV

sklearn.model_selection.GridSearchCV

目錄

\(sklearn\) 官網中的 GridSearchCV

機器學習模型中,需要人工選擇的引數稱為超引數

\(GridSearchCV\) 可以拆分為兩部分\(GridSearch、CV\),即網格搜尋、交叉驗證\(GridSearch\) 是一種調參手段,窮舉搜尋,即在所有候選的引數中,通過迴圈遍歷每一種可能,選擇最好的引數。



1. GridSearchCV 簡介

\(GridSearchCV\) 的意義是自動調參,將引數輸進去,給出最優化結果和引數。缺點是適用小資料集(小於 \(10000\)

),資料量比較大的時候,使用快速調優方法——座標下降。它其實是一種貪心演算法:拿對當前模型影響最大的引數調優,直到最優化;再拿下一個影響最大的引數調優,如此下去,直到所有的引數調整完畢。該方法的缺點是調參結果可能是區域性最優不是全域性最優。但是省時省力,巨大的優勢面前,可以一試。



2. GridSearchCV 引數

class sklearn.model_selection.GridSearchCV(estimator, param_grid, *, scoring=None, n_jobs=None, refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs', error_score=nan, return_train_score=False)

引數:

  • estimator選擇使用的模型,並且傳入除需要確定最佳的引數外的其他引數。每個分類器都需要一個 \(scoring\) 引數,或者 \(score\) 方法,如:estimator = RandomForestClassifier(min_sample_split=100, min_samples_leaf=20, max_depth=8, max_features='sqrt', random_state=10)
  • param_grid調參的引數列表或字典。\(SVM\) 分類模型中:param_grid = {"C":[0.1, 1, 10], "gamma": [0.1, 0.2, 0.3]}
    ,這樣就有 \(9\) 中超引數組合來進行網格搜尋,選擇一個擬合分數最好的超平面係數。
  • scoring=None模型評價標準\(scoring = None\) 表示使用 \(estimator\) 的誤差估計函式。如:scoring = "roc_auc"
  • n_jobs=None並行數。\(n\_jobs=None\) 表示預設值取 \(1\)\(n\_jobs=-1\) 表示與 \(CPU\) 核數一致。
  • refit=True:預設 \(True\),表示程式以交叉驗證得到的最佳引數後,用最佳引數再次 \(fit\) 一遍全部資料集。作為最終效能評估的最佳模型引數。
  • cv = None交叉驗證引數\(cv = None\) 表示預設值取 \(3\),使用三折交叉驗證。


3. 常用方法、屬性

grid.fit():執行網格搜尋。

best_params_:描述已取得最佳結果的引數組合。

best_score_:提供優化過程期間觀察到的最好的評分。

cv_results_:不同引數下模型的交叉驗證結果。