1. 程式人生 > >python中sklearn實現交叉驗證

python中sklearn實現交叉驗證

質量要比數量重要,就像一個本壘打勝過兩個雙打。——《螞蟻金服》

1、概述

在實驗資料分析中,有些演算法需要用現有的資料構建模型,如卷積神經網路(CNN),這類演算法稱為監督學習(Supervisied Learning)。構建模型需要的資料稱為訓練資料。

模型構建完後,需要利用資料驗證模型的正確性,這部分資料稱為測試資料。測試資料不能用於構建模型中,只能用於最後檢驗模型的準確性。

有時候模型的構建的過程中,也需要檢驗模型,輔助模型構建。所以會將訓練資料分為兩個部分,1)訓練資料;2)驗證資料。
將資料分類就要採用交叉驗證的方法,個人寫的交叉驗證演算法不可避免有一定缺陷,考慮使用強大sklearn包可以實現交叉驗證演算法。

2、python實現

請注意:以下的方法實現根據最新的sklearn版本實現,老版本的函式很多已經過期。

2.1 K次交叉檢驗(K-Fold Cross Validation)

K次交叉檢驗的大致思想是將資料大致分為K個子樣本,每次取一個樣本作為驗證資料,取餘下的K-1個樣本作為訓練資料。

from sklearn.model_selection import KFold
import numpy as np
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([1, 2, 3, 4])
kf = KFold(n_splits=2
) for train_index, test_index in kf.split(X): print("TRAIN:", train_index, "TEST:", test_index) X_train, X_test = X[train_index], X[test_index] y_train, y_test = y[train_index], y[test_index]

2.2 Stratified k-fold

StratifiedKFold()這個函式較常用,比KFold的優勢在於將k折資料按照百分比劃分資料集,每個類別百分比在訓練集和測試集中都是一樣,這樣能保證不會有某個類別的資料在訓練集中而測試集中沒有這種情況,同樣不會在訓練集中沒有全在測試集中,這樣會導致結果糟糕透頂。

from sklearn.model_selection import StratifiedKFold
import numpy as np

X = np.array([[1, 2], [3, 4], [1, 2], [3, 4]])
y = np.array([0, 0, 1, 1])
skf = StratifiedKFold(n_splits=2)
for train_index, test_index in skf.split(X, y):
    print("TRAIN:", train_index, "TEST:", test_index)
    X_train, X_test = X[train_index], X[test_index]
    y_train, y_test = y[train_index], y[test_index]

2.3 train_test_split

隨機根據比例分配訓練集和測試集。這個函式可以調整隨機種子。

import numpy as np
from sklearn.model_selection import train_test_split
X, y = np.arange(10).reshape((5, 2)), range(5)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.33, random_state=42)

3、總結

sklearn功能非常強大提供的交叉驗證函式也非常多,如:

           'BaseCrossValidator',
           'GridSearchCV',
           'TimeSeriesSplit',
           'KFold',
           'GroupKFold',
           'GroupShuffleSplit',
           'LeaveOneGroupOut',
           'LeaveOneOut',
           'LeavePGroupsOut',
           'LeavePOut',
           'ParameterGrid',
           'ParameterSampler',
           'PredefinedSplit',
           'RandomizedSearchCV',
           'ShuffleSplit',
           'StratifiedKFold',
           'StratifiedShuffleSplit',
           'check_cv',
           'cross_val_predict',
           'cross_val_score',
           'fit_grid_point',
           'learning_curve',
           'permutation_test_score',
           'train_test_split',
           'validation_curve'

感興趣的可以檢視sklearn原始碼。