GridSearchCV網格搜尋得到最佳超引數, 在K近鄰演算法中的應用
阿新 • • 發佈:2020-07-29
最近在學習機器學習中的K近鄰演算法,KNeighborsClassifier 看似簡單實則裡面有很多的引數配置, 這些引數直接影響到預測的準確率. 很自然的問題就是如何找到最優引數配置? 這就需要用到GridSearchCV 網格搜尋模型.
在沒有學習到GridSearchCV 網格搜尋模型之前, 尋找最優引數配置是通過人為改變引數, 來觀察預測結果準確率的. 具體步驟如下:
- 修改引數配置
- fit 訓練集
- 預測測試集
- 預測結果與真實結果對比
- 重複上述步驟
GridSearchCV 網格搜尋模型尋找最優引數的步驟如下:
- 將各種引數配置封裝為列表
- 例項化分類器
- 使用GridSearchCV 為分類器和引數建模
- 例項化模型, 並用新的模型物件fit訓練集
- 得到最好的引數配置
- 用最優引數去預測資料
於是我的疑問就來了,GridSearchCV 並沒有去預測測試集,進而得到預測結果,並在與真實結果的對比中找到最優的引數配置, 沒有這個步驟,它是怎麼得到最優引數的? 搜尋了很多,終於在這個網頁中得到了想要的資訊: python – GridSearchCV是否執行交叉驗證?http://www.cocoachina.com/articles/67515
簡單說就是我們把訓練集傳遞給GridSearchCV, 它會進一步將訓練集分為訓練集和測試集, 然後通過不斷調整超引數, 進行交叉驗證, 最後獲得最優引數.
GridSearchCV會主動將資料分為訓練集和測試集,這就是原因所在了.
程式碼實現:
1 from sklearn import datasets
2 from sklearn.model_selection import train_test_split
3 from sklearn.neighbors import KNeighborsClassifier
4 from sklearn.metrics import accuracy_score
5 from sklearn.model_selection import GridSearchCV
6
7
8 # 1/獲取資料
9 digits = datasets.load_digits()
10 X = digits.data
11 y = digits.target
12
13 # 2/分割資料,得到訓練集和測試集
14 X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=1)
15
16
17 # 3/超引數配置
18 param_grid = [
19 {
20 "weights":["uniform"],
21 "n_neighbors":[i for i in range(1,11)]
22 },
23 {
24 "weights":["distance"],
25 "n_neighbors":[i for i in range(1,11)],
26 "p":[i for i in range(1,6)]
27 }
28 ]
29
30
31 # 4/為分類器和超引數搭建模型
32 knn_clf = KNeighborsClassifier()
33 grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)
34
35 # 5/例項化模型(多種引數配置的分類器)fit訓練集,
36 # 本質上是將訓練集進一步分為訓練集和測試集,得到最好的引數配置
37 # 因為要不斷嘗試各種引數交叉驗證,所以非常耗時
38 grid_search.fit(X_train, y_train)
39
40 # 6/
41 # 最終拿到最佳引數配置分類器 best_estimator_
42 knn_clf = grid_search.best_estimator_
43
44 # 7/使用最佳分類器對測試集預測
45 y_predict = knn_clf.predict(X_test)
46
47 # 8/得到準確率
48 accuracy_score(y_test, y_predict))