sklearn 中的交叉驗證
sklearn中的交叉驗證(Cross-Validation)
sklearn是利用python進行機器學習中一個非常全面和好用的第三方庫,用過的都說好。今天主要記錄一下sklearn中關於交叉驗證的各種用法,主要是對sklearn官方文檔 Cross-validation: evaluating estimator performance進行講解,英文水平好的建議讀官方文檔,裏面的知識點很詳細。
1. cross_val_score
對數據集進行指定次數的交叉驗證並為每次驗證效果評測
其中,score 默認是以 scoring=’f1_macro’進行評測的,余外針對分類或回歸還有:
這需要from sklearn import metrics ,通過在cross_val_score 指定參數來設定評測標準;
當cv 指定為int 類型時,默認使用KFold 或StratifiedKFold 進行數據集打亂,下面會對KFold 和StratifiedKFold 進行介紹。
In [15]: from sklearn.model_selection import cross_val_score In [16]: clf = svm.SVC(kernel=‘linear‘, C=1) In [17]: scores = cross_val_score(clf, iris.data, iris.target, cv=5) In [18]: scores Out[18]: array([ 0.96666667, 1. , 0.96666667, 0.96666667, 1. ]) In [19]: scores.mean() Out[19]: 0.98000000000000009
除使用默認交叉驗證方式外,可以對交叉驗證方式進行指定,如驗證次數,訓練集測試集劃分比例等
In [20]: from sklearn.model_selection import ShuffleSplit In [21]: n_samples = iris.data.shape[0] In [22]: cv = ShuffleSplit(n_splits=3, test_size=.3, random_state=0) In [23]: cross_val_score(clf, iris.data, iris.target, cv=cv) Out[23]: array([ 0.97777778, 0.97777778, 1. ])
2. cross_val_predict
cross_val_predict 與cross_val_score 很相像,不過不同於返回的是評測效果,cross_val_predict 返回的是estimator 的分類結果(或回歸值),這個對於後期模型的改善很重要,可以通過該預測輸出對比實際目標值,準確定位到預測出錯的地方,為我們參數優化及問題排查十分的重要。
In [28]: from sklearn.model_selection import cross_val_predict In [29]: from sklearn import metrics In [30]: predicted = cross_val_predict(clf, iris.data, iris.target, cv=10) In [31]: predicted Out[31]: array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2]) In [32]: metrics.accuracy_score(iris.target, predicted) Out[32]: 0.96666666666666667
3. KFold
K折交叉驗證,這是將數據集分成K份的官方給定方案,所謂K折就是將數據集通過K次分割,使得所有數據既在訓練集出現過,又在測試集出現過,當然,每次分割中不會有重疊。相當於無放回抽樣。
In [33]: from sklearn.model_selection import KFold In [34]: X = [‘a‘,‘b‘,‘c‘,‘d‘] In [35]: kf = KFold(n_splits=2) In [36]: for train, test in kf.split(X): ...: print train, test ...: print np.array(X)[train], np.array(X)[test] ...: print ‘\n‘ ...: [2 3] [0 1] [‘c‘ ‘d‘] [‘a‘ ‘b‘] [0 1] [2 3] [‘a‘ ‘b‘] [‘c‘ ‘d‘]
4. LeaveOneOut
LeaveOneOut 其實就是KFold 的一個特例,因為使用次數比較多,因此獨立的定義出來,完全可以通過KFold 實現。
In [37]: from sklearn.model_selection import LeaveOneOut In [38]: X = [1,2,3,4] In [39]: loo = LeaveOneOut() In [41]: for train, test in loo.split(X): ...: print train, test ...: [1 2 3] [0] [0 2 3] [1] [0 1 3] [2] [0 1 2] [3] #使用KFold實現LeaveOneOtut In [42]: kf = KFold(n_splits=len(X)) In [43]: for train, test in kf.split(X): ...: print train, test ...: [1 2 3] [0] [0 2 3] [1] [0 1 3] [2] [0 1 2] [3]
5. LeavePOut
這個也是KFold 的一個特例,用KFold 實現起來稍麻煩些,跟LeaveOneOut 也很像。
In [44]: from sklearn.model_selection import LeavePOut In [45]: X = np.ones(4) In [46]: lpo = LeavePOut(p=2) In [47]: for train, test in lpo.split(X): ...: print train, test ...: [2 3] [0 1] [1 3] [0 2] [1 2] [0 3] [0 3] [1 2] [0 2] [1 3] [0 1] [2 3]
6. ShuffleSplit
ShuffleSplit 咋一看用法跟LeavePOut 很像,其實兩者完全不一樣,LeavePOut 是使得數據集經過數次分割後,所有的測試集出現的元素的集合即是完整的數據集,即無放回的抽樣,而ShuffleSplit 則是有放回的抽樣,只能說經過一個足夠大的抽樣次數後,保證測試集出現了完成的數據集的倍數。
In [48]: from sklearn.model_selection import ShuffleSplit In [49]: X = np.arange(5) In [50]: ss = ShuffleSplit(n_splits=3, test_size=.25, random_state=0) In [51]: for train_index, test_index in ss.split(X): ...: print train_index, test_index ...: [1 3 4] [2 0] [1 4 3] [0 2] [4 0 2] [1 3]
7. StratifiedKFold
對測試集合進行無放回抽樣
In [52]: from sklearn.model_selection import StratifiedKFold In [53]: X = np.ones(10) In [54]: y = [0,0,0,0,1,1,1,1,1,1] In [55]: skf = StratifiedKFold(n_splits=3) In [56]: for train, test in skf.split(X,y): ...: print train, test ...: [2 3 6 7 8 9] [0 1 4 5] [0 1 3 4 5 8 9] [2 6 7] [0 1 2 4 5 6 7] [3 8 9]
原文:https://blog.csdn.net/xiaodongxiexie/article/details/71915259
sklearn 中的交叉驗證