sklearn.model_selection.RandomizedSearchCV隨機搜尋超引數
阿新 • • 發佈:2020-09-08
GridSearchCV可以保證在指定的引數範圍內找到精度最高的引數,但是這也是網格搜尋的缺陷所在,它要求遍歷所有可能引數的組合,在面對大資料集和多引數的情況下,非常耗時。這也是我通常不會使用GridSearchCV的原因,一般會採用後一種RandomizedSearchCV隨機引數搜尋的方法
RandomizedSearchCV的使用方法其實是和GridSearchCV一致的,但它以隨機在引數空間中取樣的方式代替了GridSearchCV對於引數的網格搜尋,在對於有連續變數的引數時,RandomizedSearchCV會將其當作一個分佈進行取樣這是網格搜尋做不到的,它的搜尋能力取決於設定的n_iter引數
函式用法:
class sklearn.model_selection.RandomizedSearchCV(estimator, param_distributions, *, n_iter=10,
scoring=None, n_jobs=None, iid='deprecated', refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
random_state=None, error_score=nan, return_train_score=False)
引數詳解:
estimator:估計器
param_distributions 字典或字典列表:引數字典,key是引數名,values是引數範圍
n_iterint,預設= 10:抽取樣本是訓練次數
更多引數參考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html
RandomSearchCV是如何"隨機搜尋"的
考察其原始碼,其搜尋策略如下:
(a)對於搜尋範圍是distribution的超引數,根據給定的distribution隨機取樣;
(b)對於搜尋範圍是list的超引數,在給定的list中等概率取樣;
(c)對a、b兩步中得到的n_iter組取樣結果,進行遍歷。
(補充)如果給定的搜尋範圍均為list,則不放回抽樣n_iter次。
import numpy as np from scipy.stats import randint as sp_randint from sklearn.model_selection import RandomizedSearchCV from sklearn.datasets import load_digits from sklearn.ensemble import RandomForestClassifier # 載入資料 digits = load_digits() X, y = digits.data, digits.target # 建立一個分類器或者回歸器 clf = RandomForestClassifier(n_estimators=20) # 給定引數搜尋範圍:list or distribution param_dist = {"max_depth": [3, None], #給定list "max_features": sp_randint(1, 11), #給定distribution "min_samples_split": sp_randint(2, 11), #給定distribution "bootstrap": [True, False], #給定list "criterion": ["gini", "entropy"]} #給定list # 用RandomSearch+CV選取超引數 n_iter_search = 20 random_search = RandomizedSearchCV(clf, param_distributions=param_dist, n_iter=n_iter_search, cv=5, iid=False) clf=random_search.fit(X, y) clf.best_params_
{'bootstrap': False, 'criterion': 'entropy', 'max_depth': None, 'max_features': 9, 'min_samples_split': 8}