1. 程式人生 > 實用技巧 >sklearn.model_selection.RandomizedSearchCV隨機搜尋超引數

sklearn.model_selection.RandomizedSearchCV隨機搜尋超引數

GridSearchCV可以保證在指定的引數範圍內找到精度最高的引數,但是這也是網格搜尋的缺陷所在,它要求遍歷所有可能引數的組合,在面對大資料集和多引數的情況下,非常耗時。這也是我通常不會使用GridSearchCV的原因,一般會採用後一種RandomizedSearchCV隨機引數搜尋的方法

RandomizedSearchCV的使用方法其實是和GridSearchCV一致的,但它以隨機在引數空間中取樣的方式代替了GridSearchCV對於引數的網格搜尋,在對於有連續變數的引數時,RandomizedSearchCV會將其當作一個分佈進行取樣這是網格搜尋做不到的,它的搜尋能力取決於設定的n_iter引數

函式用法:

class sklearn.model_selection.RandomizedSearchCV(estimator, param_distributions, *, n_iter=10, 
scoring=None, n_jobs=None, iid='deprecated', refit=True, cv=None, verbose=0, pre_dispatch='2*n_jobs',
random_state=None, error_score=nan, return_train_score=False)

引數詳解:

estimator:估計器

param_distributions

字典或字典列表:引數字典,key是引數名,values是引數範圍

n_iterint,預設= 10:抽取樣本是訓練次數

更多引數參考:https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html

RandomSearchCV是如何"隨機搜尋"的

考察其原始碼,其搜尋策略如下:
(a)對於搜尋範圍是distribution的超引數,根據給定的distribution隨機取樣;
(b)對於搜尋範圍是list的超引數,在給定的list中等概率取樣;
(c)對a、b兩步中得到的n_iter組取樣結果,進行遍歷。


(補充)如果給定的搜尋範圍均為list,則不放回抽樣n_iter次。

import numpy as np
from scipy.stats import randint as sp_randint
from sklearn.model_selection import RandomizedSearchCV
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier

# 載入資料
digits = load_digits()
X, y = digits.data, digits.target

# 建立一個分類器或者回歸器
clf = RandomForestClassifier(n_estimators=20)

# 給定引數搜尋範圍:list or distribution
param_dist = {"max_depth": [3, None],                     #給定list
              "max_features": sp_randint(1, 11),          #給定distribution
              "min_samples_split": sp_randint(2, 11),     #給定distribution
              "bootstrap": [True, False],                 #給定list
              "criterion": ["gini", "entropy"]}           #給定list

# 用RandomSearch+CV選取超引數
n_iter_search = 20
random_search = RandomizedSearchCV(clf, param_distributions=param_dist,
                                   n_iter=n_iter_search, cv=5, iid=False)
clf=random_search.fit(X, y)
clf.best_params_ 
{'bootstrap': False,
 'criterion': 'entropy',
 'max_depth': None,
 'max_features': 9,
 'min_samples_split': 8}