Machine Learning in Python with scikit-learn, Documentation (7). Model selection: choosing estimators and their parameters
Model selection: choosing estimators and their parameters
Score, and cross-validated scores
Cross-validation generators
Grid-search and cross-validated estimators
Grid-search
Cross-validated estimators
Model selection: choosing estimators and their parameters
Score, and cross-validated scores
As we have seen, every estimator exposes a score method that can judge the quality of the fit (or the prediction) on new data. Bigger is better.
from sklearn import datasets, svm
digits = datasets.load_digits()
X_digits = digits.data
y_digits = digits.target
svc = svm.SVC(C=1, kernel='linear')
svc.fit(X_digits[:-100], y_digits[:-100]).score(X_digits[-100:], y_digits[-100:])
0.98
To get a better measure of prediction accuracy (which we can use as a proxy for the goodness of fit of the model), we can successively split the data into folds that we use for training and testing:
import numpy as np
X_folds = np.array_split(X_digits, 3)
y_folds = np.array_split(y_digits, 3)
scores = list()
for k in range(3):
    # We use 'list' to copy, in order to 'pop' later on
    X_train = list(X_folds)
    X_test = X_train.pop(k)
    X_train = np.concatenate(X_train)
    y_train = list(y_folds)
    y_test = y_train.pop(k)
    y_train = np.concatenate(y_train)
    scores.append(svc.fit(X_train, y_train).score(X_test, y_test))
print(scores)
[0.934..., 0.956..., 0.939...]
This procedure is called a KFold cross-validation.
Cross-validation generators
Scikit-learn has a collection of classes that can be used to generate lists of train/test indices for popular cross-validation strategies.
These classes expose a split method which accepts the input dataset to be split and yields the train/test set indices for each iteration of the chosen cross-validation strategy.
The following example demonstrates the use of the split method.
from sklearn.model_selection import KFold, cross_val_score
X = ["a", "a", "a", "b", "b", "c", "c", "c", "c", "c"]
k_fold = KFold(n_splits=5)
for train_indices, test_indices in k_fold.split(X):
    print('Train: %s | test: %s' % (train_indices, test_indices))
Train: [2 3 4 5 6 7 8 9] | test: [0 1]
Train: [0 1 4 5 6 7 8 9] | test: [2 3]
Train: [0 1 2 3 6 7 8 9] | test: [4 5]
Train: [0 1 2 3 4 5 8 9] | test: [6 7]
Train: [0 1 2 3 4 5 6 7] | test: [8 9]
The cross-validation can then be performed easily:
[svc.fit(X_digits[train], y_digits[train]).score(X_digits[test], y_digits[test])
 for train, test in k_fold.split(X_digits)]
[0.963..., 0.922..., 0.963..., 0.963..., 0.930...]
The cross-validation score can be directly calculated using the cross_val_score helper. Given an estimator, a cross-validation object and the input dataset, cross_val_score repeatedly splits the data into a training and a testing set, trains the estimator on each training set, scores it on the corresponding testing set, and returns the score of each iteration as an array.
By default the estimator's score method is used to compute the individual scores.
Refer to the metrics module to learn more about the available scoring methods.
cross_val_score(svc, X_digits, y_digits, cv=k_fold, n_jobs=-1)
array([0.96388889, 0.92222222, 0.9637883 , 0.9637883 , 0.93036212])
n_jobs=-1 means that the computation will be dispatched on all the CPU cores of the computer.
Alternatively, the scoring argument can be provided to specify an alternative scoring method, e.g. precision_macro.
cross_val_score(svc, X_digits, y_digits, cv=k_fold,
                scoring='precision_macro')
array([0.96578289, 0.92708922, 0.96681476, 0.96362897, 0.93192644])
Cross-validation generators
KFold (n_splits, shuffle, random_state): Splits the data into K folds, trains on K-1, then tests on the left-out fold.
StratifiedKFold (n_splits, shuffle, random_state): Same as KFold, but preserves the class distribution within each fold.
GroupKFold (n_splits): Ensures that the same group is not in both the testing and training sets.
ShuffleSplit (n_splits, test_size, train_size, random_state): Generates train/test indices based on a random permutation.
StratifiedShuffleSplit: Same as ShuffleSplit, but preserves the class distribution within each iteration.
GroupShuffleSplit: Ensures that the same group is not in both the testing and training sets.
LeaveOneGroupOut (): Takes a group array to group observations.
LeavePGroupsOut (n_groups): Leaves P groups out.
LeaveOneOut (): Leaves one observation out.
LeavePOut (p): Leaves P observations out.
PredefinedSplit: Generates train/test indices based on predefined splits.
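As a small illustration (my own example, not part of the original tutorial), the difference between KFold and StratifiedKFold can be seen on an imbalanced toy dataset; the data and labels below are assumed purely for demonstration:

import numpy as np
from sklearn.model_selection import KFold, StratifiedKFold

X = np.ones((12, 1))
y = np.array([0] * 8 + [1] * 4)   # imbalanced classes: 8 of '0', 4 of '1'

for cv in (KFold(n_splits=4), StratifiedKFold(n_splits=4)):
    print(type(cv).__name__)
    for train, test in cv.split(X, y):
        # KFold can yield test folds containing a single class;
        # StratifiedKFold keeps the 2:1 class ratio in every fold.
        print("  test fold classes:", y[test])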
Exercise
[Figure: cross-validation score of a linear-kernel SVC on the digits dataset as a function of C]
On the digits dataset, plot the cross-validation score of an SVC estimator with a linear kernel as a function of the parameter C (use a logarithmic grid of 10 points, from 10^-10 to 10^0, as in C_s below).
import numpy as np
from sklearn.model_selection import cross_val_score
from sklearn import datasets, svm
digits = datasets.load_digits()
X = digits.data
y = digits.target
svc = svm.SVC(kernel='linear')
C_s = np.logspace(-10, 0, 10)
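One minimal way to finish the exercise might look like this (the plotting details are my own sketch, not necessarily the linked solution; it assumes the variables defined above):

import matplotlib.pyplot as plt

scores, scores_std = [], []
for C in C_s:
    svc.set_params(C=C)                # reuse the same estimator with a new C
    this_scores = cross_val_score(svc, X, y, n_jobs=-1)
    scores.append(np.mean(this_scores))
    scores_std.append(np.std(this_scores))

plt.semilogx(C_s, scores)                                # mean CV score vs C
plt.semilogx(C_s, np.array(scores) + scores_std, 'b--')  # +1 standard deviation
plt.semilogx(C_s, np.array(scores) - scores_std, 'b--')  # -1 standard deviation
plt.xlabel('Parameter C')
plt.ylabel('Cross-validation score')
plt.show()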
Solution: Cross-validation on Digits Dataset Exercise
Grid-search and cross-validated estimators
Grid-search
scikit-learn provides an object that, given data, computes the score during the fit of an estimator on a parameter grid and chooses the parameters to maximize the cross-validation score. This object takes an estimator during construction and exposes the estimator API:
from sklearn.model_selection import GridSearchCV, cross_val_score
Cs = np.logspace(-6, -1, 10)
clf = GridSearchCV(estimator=svc, param_grid=dict(C=Cs),
                   n_jobs=-1)
clf.fit(X_digits[:1000], y_digits[:1000])
GridSearchCV(cv=None,...
clf.best_score_
0.925...
clf.best_estimator_.C
0.0077...
Prediction performance on the test set is not as good as on the train set:
clf.score(X_digits[1000:], y_digits[1000:])
0.943...
By default, GridSearchCV uses 3-fold cross-validation. However, if it detects that a classifier is passed, rather than a regressor, it uses a stratified 3-fold. The default will change to 5-fold cross-validation in scikit-learn version 0.22.
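Because this default has changed across versions, it can be safer not to rely on it; a small sketch (my own, not from the original) that pins the strategy explicitly, assuming svc, Cs, X_digits and y_digits from the examples above:

# Passing cv explicitly pins the cross-validation strategy, so the results
# do not depend on the version-specific default.
clf = GridSearchCV(estimator=svc, param_grid=dict(C=Cs), cv=5, n_jobs=-1)
clf.fit(X_digits[:1000], y_digits[:1000])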
Nested cross-validation
cross_val_score(clf, X_digits, y_digits)
array([0.938..., 0.963..., 0.944...])
Two cross-validation loops are performed in parallel: one by the GridSearchCV estimator to set C and the other one by cross_val_score to measure the prediction performance of the estimator. The resulting scores are unbiased estimates of the prediction score on new data.
Warning: You cannot nest objects with parallel computing (n_jobs different than 1).
Cross-validated estimators
Cross-validation to set a parameter can be done more efficiently on an algorithm-by-algorithm basis. This is why, for certain estimators, scikit-learn exposes cross-validation estimators (see Cross-validation: evaluating estimator performance) that set their parameter automatically by cross-validation:
from sklearn import linear_model, datasets
lasso = linear_model.LassoCV(cv=3)
diabetes = datasets.load_diabetes()
X_diabetes = diabetes.data
y_diabetes = diabetes.target
lasso.fit(X_diabetes, y_diabetes)
LassoCV(alphas=None, copy_X=True, cv=3, eps=0.001, fit_intercept=True,
    max_iter=1000, n_alphas=100, n_jobs=None, normalize=False,
    positive=False, precompute='auto', random_state=None,
    selection='cyclic', tol=0.0001, verbose=False)
The estimator chose its lambda automatically:
lasso.alpha_
0.01229...
These estimators are called similarly to their counterparts, with 'CV' appended to their name.
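For instance, RidgeCV follows the same pattern; a brief illustration (my own example, not from the original text):

import numpy as np
from sklearn import datasets
from sklearn.linear_model import RidgeCV

diabetes = datasets.load_diabetes()
ridge = RidgeCV(alphas=np.logspace(-6, 6, 13))
ridge.fit(diabetes.data, diabetes.target)
print(ridge.alpha_)   # regularization strength selected by cross-validation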
Exercise
On the diabetes dataset, find the optimal regularization parameter alpha.
Bonus: How much can you trust the selection of alpha?
from sklearn import datasets
from sklearn.linear_model import LassoCV
from sklearn.linear_model import Lasso
from sklearn.model_selection import KFold
from sklearn.model_selection import GridSearchCV
diabetes = datasets.load_diabetes()
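One possible sketch of a solution, including the bonus question (the fold-by-fold stability check is my own approach and may differ from the linked solution):

import numpy as np

X, y = diabetes.data, diabetes.target
alphas = np.logspace(-4, -0.5, 30)

# Refit LassoCV on each of 3 outer folds; if the selected alpha varies a lot
# between folds, the choice of alpha should not be trusted too much (bonus).
k_fold = KFold(n_splits=3)
for k, (train, test) in enumerate(k_fold.split(X)):
    lasso_cv = LassoCV(alphas=alphas, cv=3).fit(X[train], y[train])
    print("fold %d: alpha = %.5f, test score = %.3f"
          % (k, lasso_cv.alpha_, lasso_cv.score(X[test], y[test])))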
Solution: Cross-validation on diabetes Dataset Exercise