sklearn hyperparameter tuning methods [GridSearchCV, RandomizedSearchCV]

Model tuning: assuming you already have a shortlist of promising models, here are a few ways to fine-tune them.

1. Tuning the hyperparameters

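  • Grid search
    One option is to fiddle with the hyperparameters by hand, but that is tedious. GridSearchCV can do the search for you instead: you tell it which hyperparameters to tune and which values to try, and it evaluates every combination with cross-validation.
    class sklearn.model_selection.GridSearchCV(estimator, param_grid, scoring=None, fit_params=None, n_jobs=None, iid='warn', refit=True, cv='warn', verbose=0, pre_dispatch='2*n_jobs', error_score='raise-deprecating', return_train_score='warn')
    (This is the scikit-learn 0.20 signature. In later versions fit_params and iid were removed, cv defaults to 5-fold, and return_train_score defaults to False.)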

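    Parameters:
  • estimator: an object implementing the scikit-learn estimator interface. It must provide a score method, otherwise a scoring argument must be passed.
  • param_grid: a dict or a list of dicts
    [{hyperparameter_1: [candidate values]}, {hyperparameter_2: [candidate values]}, ...]. Each dict defines one grid: every combination of the values inside a dict is tried, and the grids from separate dicts are explored one after another.
  • scoring: the evaluation metric. scikit-learn scorers follow a "greater is better" convention, so error metrics are negated (e.g. 'neg_mean_squared_error').
  • cv: the number of folds the training set is split into for cross-validation.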
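Because of the "greater is better" convention, an error-based scoring string returns a negative number. The minimal sketch below checks this on a small synthetic problem (make_regression and LinearRegression are stand-ins chosen for illustration, not part of the original post): the scorer returned by get_scorer("neg_mean_squared_error") equals minus the ordinary MSE.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import LinearRegression
from sklearn.metrics import get_scorer, mean_squared_error

# Small synthetic regression problem, used only for illustration
X, y = make_regression(n_samples=100, n_features=3, noise=10.0, random_state=0)
model = LinearRegression().fit(X, y)

# Scorers follow "greater is better", so this one returns the *negated* MSE
scorer = get_scorer("neg_mean_squared_error")
score = scorer(model, X, y)
mse = mean_squared_error(y, model.predict(X))
print(score, mse)  # score is -mse
```

This is why the results further down are converted back with np.sqrt(-score).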

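Attributes:

  • best_params_: the best hyperparameter combination found

  • best_estimator_: the best model, retrained on the whole training set (only when refit=True)

  • cv_results_: a dict with the cross-validation results of every combination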

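from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

param_grid = [
    {'n_estimators': [3, 10, 30], 'max_features': [2, 4, 6, 8]},
    {'bootstrap': [False], 'n_estimators': [3, 10], 'max_features': [2, 3, 4]},
]
forest_reg = RandomForestRegressor()
# return_train_score=True keeps the split*_train_score columns of cv_results_
grid_search = GridSearchCV(forest_reg, param_grid, cv=5,
                           scoring='neg_mean_squared_error',
                           return_train_score=True)
# housing_prepared / housing_labels: the prepared training data from the earlier steps
grid_search.fit(housing_prepared, housing_labels)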
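housing_prepared and housing_labels come from earlier data-preparation steps that are not shown here. The following is a self-contained sketch that runs the same param_grid on synthetic data instead. The make_regression dataset (8 features, so max_features up to 8 is valid) and the fixed random_state values are assumptions made only so the example runs on its own.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import GridSearchCV

# Synthetic stand-in for the housing data (assumption: 8 numeric features)
X, y = make_regression(n_samples=200, n_features=8, noise=20.0, random_state=42)

param_grid = [
    {"n_estimators": [3, 10, 30], "max_features": [2, 4, 6, 8]},
    {"bootstrap": [False], "n_estimators": [3, 10], "max_features": [2, 3, 4]},
]

grid_search = GridSearchCV(
    RandomForestRegressor(random_state=42),
    param_grid,
    cv=5,
    scoring="neg_mean_squared_error",
    return_train_score=True,
)
grid_search.fit(X, y)

# best_score_ is the best mean negated MSE; convert it to an RMSE
best_rmse = np.sqrt(-grid_search.best_score_)
print(grid_search.best_params_, best_rmse)
```

The exact best combination depends on the data, so no particular output is implied here.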

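Based on param_grid, scikit-learn first evaluates all $3 \times 4 = 12$ combinations of n_estimators and max_features from the first dict. It then evaluates all $2 \times 3 = 6$ combinations of n_estimators and max_features from the second dict with bootstrap=False (its default is True). That makes $12 + 6 = 18$ hyperparameter combinations in total.
Each combination is trained 5 times (once per fold), so there are $18 \times 5 = 90$ training rounds in all. When training is finished, the best combination is available in best_params_.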
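The counting above can be checked with ParameterGrid, which expands a param_grid into the same list of candidates that GridSearchCV iterates over. This is a small sketch. Note that with refit=True, the final retraining on the full training set adds one more fit on top of these 90.

```python
from sklearn.model_selection import ParameterGrid

param_grid = [
    {"n_estimators": [3, 10, 30], "max_features": [2, 4, 6, 8]},
    {"bootstrap": [False], "n_estimators": [3, 10], "max_features": [2, 3, 4]},
]

n_combinations = len(ParameterGrid(param_grid))  # 12 + 6
n_folds = 5
n_fits = n_combinations * n_folds  # training rounds during cross-validation
print(n_combinations, n_fits)  # 18 90
```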

grid_search.best_params_

out:

{'max_features': 8, 'n_estimators': 30}

You can also get the best estimator directly:

grid_search.best_estimator_

out:

RandomForestRegressor(bootstrap=True, criterion='mse', max_depth=None,
max_features=8, max_leaf_nodes=None, min_impurity_decrease=0.0,
min_impurity_split=None, min_samples_leaf=1,
min_samples_split=2, min_weight_fraction_leaf=0.0,
n_estimators=30, n_jobs=1, oob_score=False, random_state=None,
verbose=0, warm_start=False)

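If GridSearchCV is initialized with refit=True (the default), then once cross-validation has found the best estimator, it retrains that estimator on the whole training set. This is usually a good idea, since feeding it more data is likely to improve its performance.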
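The refit behaviour can be checked with a cheap, deterministic model. The sketch below uses Ridge on synthetic data as an illustrative stand-in for the random forest. It retrains the best configuration by hand on the full training set and compares it with best_estimator_. It also shows that with refit=False only the search results (including best_params_) are kept, and no best_estimator_ is built.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=150, n_features=5, noise=5.0, random_state=0)
grid = {"alpha": [0.1, 1.0, 10.0]}

# refit=True (the default): the best configuration is retrained on all of X, y
search = GridSearchCV(Ridge(), grid, cv=5, scoring="neg_mean_squared_error")
search.fit(X, y)

# Retraining the best configuration by hand gives the same model
manual = Ridge(**search.best_params_).fit(X, y)
print(np.allclose(search.best_estimator_.coef_, manual.coef_))  # True

# refit=False: only the CV results are kept, no best_estimator_ is built
no_refit = GridSearchCV(Ridge(), grid, cv=5, scoring="neg_mean_squared_error",
                        refit=False).fit(X, y)
print(hasattr(no_refit, "best_estimator_"))  # False
```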

cv_results_: the evaluation results are stored in a dict, which can be converted into a DataFrame with one row per hyperparameter combination.

import pandas as pd

cv = pd.DataFrame(grid_search.cv_results_)
cv

# out:



	mean_fit_time	std_fit_time	mean_score_time	std_score_time	param_max_features	param_n_estimators	param_bootstrap	params	split0_test_score	split1_test_score	...	mean_test_score	std_test_score	rank_test_score	split0_train_score	split1_train_score	split2_train_score	split3_train_score	split4_train_score	mean_train_score	std_train_score
0	0.085940	0.025568	0.004391	0.001827	2	3	NaN	{'max_features': 2, 'n_estimators': 3}	-3.827812e+09	-4.092971e+09	...	-4.139394e+09	1.959892e+08	18	-1.108742e+09	-1.076285e+09	-1.151262e+09	-1.127172e+09	-1.114365e+09	-1.115565e+09	2.449443e+07
1	0.237965	0.021394	0.010079	0.000971	2	10	NaN	{'max_features': 2, 'n_estimators': 10}	-2.742609e+09	-3.333789e+09	...	-3.113304e+09	2.395274e+08	11	-5.701092e+08	-6.094081e+08	-5.785905e+08	-5.931616e+08	-5.756676e+08	-5.853874e+08	1.422342e+07
2	0.767676	0.092776	0.029941	0.004852	2	30	NaN	{'max_features': 2, 'n_estimators': 30}	-2.715244e+09	-2.902911e+09	...	-2.802893e+09	1.542347e+08	8	-4.472685e+08	-4.271909e+08	-4.250249e+08	-4.289724e+08	-4.274656e+08	-4.311845e+08	8.140110e+06
3	0.107399	0.002467	0.003590	0.000482	4	3	NaN	{'max_features': 4, 'n_estimators': 3}	-3.698121e+09	-3.903447e+09	...	-3.666869e+09	2.123558e+08	15	-9.640156e+08	-9.257596e+08	-1.003782e+09	-9.539664e+08	-8.817442e+08	-9.458535e+08	4.065737e+07
4	0.348854	0.003953	0.009683	0.000399	4	10	NaN	{'max_features': 4, 'n_estimators': 10}	-2.721300e+09	-2.868056e+09	...	-2.788953e+09	1.090418e+08	7	-5.299723e+08	-4.985359e+08	-4.892567e+08	-5.204567e+08	-5.260378e+08	-5.128519e+08	1.604050e+07
5	1.058535	0.007244	0.027150	0.000750	4	30	NaN	{'max_features': 4, 'n_estimators': 30}	-2.412966e+09	-2.617706e+09	...	-2.575291e+09	1.220214e+08	3	-4.002692e+08	-3.978664e+08	-3.819384e+08	-4.050417e+08	-3.929495e+08	-3.956131e+08	7.870790e+06
6	0.157606	0.014043	0.003394	0.000200	6	3	NaN	{'max_features': 6, 'n_estimators': 3}	-3.292618e+09	-3.698856e+09	...	-3.535907e+09	1.405519e+08	14	-8.680883e+08	-9.112255e+08	-9.112524e+08	-1.004211e+09	-9.125849e+08	-9.214724e+08	4.468549e+07
7	0.491094	0.019871	0.009585	0.000372	6	10	NaN	{'max_features': 6, 'n_estimators': 10}	-2.553328e+09	-2.791239e+09	...	-2.750653e+09	1.162331e+08	6	-4.944491e+08	-4.779433e+08	-4.857806e+08	-5.040733e+08	-5.053104e+08	-4.935113e+08	1.052407e+07
8	1.446224	0.011595	0.027342	0.000976	6	30	NaN	{'max_features': 6, 'n_estimators': 30}	-2.358673e+09	-2.551963e+09	...	-2.516006e+09	1.406579e+08	2	-3.919785e+08	-3.910130e+08	-3.912702e+08	-3.954620e+08	-3.815968e+08	-3.902641e+08	4.618501e+06
9	0.184060	0.003479	0.003398	0.000202	8	3	NaN	{'max_features': 8, 'n_estimators': 3}	-3.214309e+09	-3.749000e+09	...	-3.492941e+09	2.259341e+08	13	-8.754845e+08	-9.164619e+08	-9.071570e+08	-9.245905e+08	-8.557108e+08	-8.958809e+08	2.609451e+07
10	0.613070	0.003497	0.009779	0.000676	8	10	NaN	{'max_features': 8, 'n_estimators': 10}	-2.524796e+09	-2.749792e+09	...	-2.695701e+09	1.184566e+08	4	-4.814605e+08	-4.924878e+08	-4.953451e+08	-5.081368e+08	-5.055688e+08	-4.965998e+08	9.604441e+06
11	1.843982	0.014717	0.026444	0.000943	8	30	NaN	{'max_features': 8, 'n_estimators': 30}	-2.375197e+09	-2.517777e+09	...	-2.506430e+09	1.330733e+08	1	-3.884794e+08	-3.844676e+08	-3.707626e+08	-3.920550e+08	-3.907572e+08	-3.853043e+08	7.714260e+06
12	0.105807	0.003103	0.003793	0.000399	2	3	False	{'bootstrap': False, 'max_features': 2, 'n_est...	-3.771914e+09	-3.740538e+09	...	-3.750444e+09	9.489691e+07	17	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	0.000000e+00	0.000000e+00
13	0.347154	0.003689	0.010282	0.000509	2	10	False	{'bootstrap': False, 'max_features': 2, 'n_est...	-2.674186e+09	-2.990817e+09	...	-2.920952e+09	1.338905e+08	10	-4.739042e+02	-1.947960e+03	-1.514005e+02	-0.000000e+00	-7.418622e-01	-5.148012e+02	7.371504e+02
14	0.137443	0.003622	0.003394	0.000200	3	3	False	{'bootstrap': False, 'max_features': 3, 'n_est...	-3.262696e+09	-3.673369e+09	...	-3.667646e+09	3.201594e+08	16	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	0.000000e+00	0.000000e+00
15	0.458953	0.010759	0.010680	0.000400	3	10	False	{'bootstrap': False, 'max_features': 3, 'n_est...	-2.672303e+09	-2.797555e+09	...	-2.819182e+09	1.591897e+08	9	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	0.000000e+00	0.000000e+00
16	0.170678	0.001511	0.004197	0.000249	4	3	False	{'bootstrap': False, 'max_features': 4, 'n_est...	-3.360936e+09	-3.505000e+09	...	-3.424422e+09	1.686945e+08	12	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	0.000000e+00	0.000000e+00
17	0.566040	0.004567	0.011083	0.000581	4	10	False	{'bootstrap': False, 'max_features': 4, 'n_est...	-2.507133e+09	-2.778602e+09	...	-2.701759e+09	1.294327e+08	5	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	-0.000000e+00	0.000000e+00	0.000000e+00
18 rows × 23 columns

As you can see, there are 18 rows, one for each of the 18 hyperparameter combinations.
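With 23 columns the full table is hard to read. A common approach is to keep only a few columns and sort by rank_test_score. This is a self-contained sketch. The DecisionTreeRegressor, its small grid and the synthetic data are illustrative assumptions standing in for the forest search above. A stable sort is used so that, on ties, the top row matches the combination GridSearchCV reports as best.

```python
import pandas as pd
from sklearn.datasets import make_regression
from sklearn.model_selection import GridSearchCV
from sklearn.tree import DecisionTreeRegressor

X, y = make_regression(n_samples=200, n_features=6, noise=10.0, random_state=0)
param_grid = {"max_depth": [2, 4, 8], "min_samples_leaf": [1, 5]}

search = GridSearchCV(DecisionTreeRegressor(random_state=0), param_grid, cv=5,
                      scoring="neg_mean_squared_error")
search.fit(X, y)

# One row per combination; keep the readable columns and sort best-first
cv = pd.DataFrame(search.cv_results_)
summary = (cv[["params", "mean_test_score", "std_test_score", "rank_test_score"]]
           .sort_values("rank_test_score", kind="mergesort"))
print(summary)
```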

Print each hyperparameter combination together with its RMSE:

import numpy as np

# mean_test_score is the negated MSE, so sqrt(-score) gives the RMSE
for mean_score, params in zip(cv["mean_test_score"], cv["params"]):
    print(np.sqrt(-mean_score), params)

# out:
64338.12255034114 {'max_features': 2, 'n_estimators': 3}
55796.98682826028 {'max_features': 2, 'n_estimators': 10}
52942.353620944974 {'max_features': 2, 'n_estimators': 30}
60554.68102018701 {'max_features': 4, 'n_estimators': 3}
52810.54319153441 {'max_features': 4, 'n_estimators': 10}
50747.320348497095 {'max_features': 4, 'n_estimators': 30}
59463.49112391854 {'max_features': 6, 'n_estimators': 3}
52446.67106233131 {'max_features': 6, 'n_estimators': 10}
50159.80848511047 {'max_features': 6, 'n_estimators': 30}
59101.107722861816 {'max_features': 8, 'n_estimators': 3}
51920.13895373909 {'max_features': 8, 'n_estimators': 10}
50064.25380671872 {'max_features': 8, 'n_estimators': 30}
61240.86657596463 {'bootstrap': False, 'max_features': 2, 'n_estimators': 3}
54045.83360004432 {'bootstrap': False, 'max_features': 2, 'n_estimators': 10}
60561.09106204302 {'bootstrap': False, 'max_features': 3, 'n_estimators': 3}
53095.965054469176 {'bootstrap': False, 'max_features': 3, 'n_estimators': 10}
58518.56200119758 {'bootstrap': False, 'max_features': 4, 'n_estimators': 3}
51978.44739222989 {'bootstrap': False, 'max_features': 4, 'n_estimators': 10}

The best combination is max_features=8, n_estimators=30, with an RMSE of about 50064. This is better than the RMSE of 52634 obtained with the default hyperparameters.
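Because the square root is monotonic, the combination with the lowest RMSE computed this way is exactly the one GridSearchCV selects. Strictly, this value is the square root of the mean per-fold MSE, not the mean of per-fold RMSEs. The sketch below checks the correspondence, using Ridge with an illustrative alpha grid on synthetic data as an assumed stand-in for the forest search.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.linear_model import Ridge
from sklearn.model_selection import GridSearchCV

X, y = make_regression(n_samples=150, n_features=5, noise=15.0, random_state=1)
search = GridSearchCV(Ridge(), {"alpha": [0.01, 0.1, 1.0, 10.0, 100.0]}, cv=5,
                      scoring="neg_mean_squared_error")
search.fit(X, y)

# mean_test_score holds negated mean MSE; sqrt(-score) turns it back into an RMSE
rmse = np.sqrt(-search.cv_results_["mean_test_score"])
for value, params in zip(rmse, search.cv_results_["params"]):
    print(f"{value:.2f}", params)

# sqrt is monotonic, so the lowest RMSE is the combination GridSearchCV picked
print(search.cv_results_["params"][np.argmin(rmse)] == search.best_params_)  # True
```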

  • RandomizedSearchCV
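    GridSearchCV works well when you are exploring relatively few combinations. When the search space is large (for example, with many hyperparameters), RandomizedSearchCV is usually the better choice. It is used in essentially the same way as GridSearchCV, but instead of trying every combination on a grid, it evaluates a fixed number of combinations sampled at random from the parameter space. For a continuous hyperparameter you can pass a distribution to sample from, which a grid search cannot do. The amount of searching is controlled by the n_iter parameter, the number of sampled combinations.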

http://scikit-learn.org/stable/modules/generated/sklearn.model_selection.RandomizedSearchCV.html
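As a minimal sketch of the random search described above, the example below replaces the value lists with scipy.stats.randint distributions and caps the search at n_iter=10 sampled combinations. The synthetic data, the distribution bounds and the random_state values are illustrative assumptions, not values from the original post.

```python
import numpy as np
from scipy.stats import randint
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor
from sklearn.model_selection import RandomizedSearchCV

X, y = make_regression(n_samples=200, n_features=8, noise=20.0, random_state=42)

# Distributions instead of fixed lists: integers drawn uniformly from [low, high)
param_distribs = {
    "n_estimators": randint(low=5, high=50),
    "max_features": randint(low=1, high=9),
}

rnd_search = RandomizedSearchCV(
    RandomForestRegressor(random_state=42),
    param_distributions=param_distribs,
    n_iter=10,  # number of sampled combinations, i.e. the search budget
    cv=5,
    scoring="neg_mean_squared_error",
    random_state=42,
)
rnd_search.fit(X, y)
print(rnd_search.best_params_, np.sqrt(-rnd_search.best_score_))
```

Unlike the grid search, the cost here is fixed by n_iter × cv (here 10 × 5 = 50 fits), regardless of how wide the distributions are.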

2. Combining models

  • Ensemble methods
    Another way to fine-tune a system is to combine the models that perform best. For example, a random forest is an ensemble of decision trees.
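One simple way to combine different well-performing models in scikit-learn is VotingRegressor, which averages the members' predictions. The sketch below is only an illustration: pairing a random forest with a Ridge model and the synthetic data are assumptions, not the post's own setup. It checks that the unweighted ensemble prediction is the plain mean of the fitted members, and that the forest is itself an ensemble of 30 trees.

```python
import numpy as np
from sklearn.datasets import make_regression
from sklearn.ensemble import RandomForestRegressor, VotingRegressor
from sklearn.linear_model import Ridge

X, y = make_regression(n_samples=200, n_features=8, noise=20.0, random_state=0)

# Combine two different models; VotingRegressor averages their predictions
ensemble = VotingRegressor([
    ("forest", RandomForestRegressor(n_estimators=30, random_state=0)),
    ("ridge", Ridge(alpha=1.0)),
])
ensemble.fit(X, y)

# With no weights, the ensemble prediction is the mean of the fitted members
members = np.column_stack([m.predict(X) for m in ensemble.estimators_])
print(np.allclose(ensemble.predict(X), members.mean(axis=1)))  # True

# The random forest member is itself an ensemble of decision trees
print(len(ensemble.named_estimators_["forest"].estimators_))  # 30
```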