程式人生 > 實用技巧 > 09. 尋找最好的超引數

09.尋找最好的超引數

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.metrics import accuracy_score

1、獲取資料

# 1. Load the handwritten-digits dataset bundled with scikit-learn.
digits = datasets.load_digits()
X = digits.data    # feature matrix: one flattened 8x8 grayscale image per row
# NOTE(review): the original assignment was broken across two lines
# ("y" / "= digits.target"), which is a SyntaxError; rejoined here.
y = digits.target  # class labels 0-9

2、分割資料,得到訓練集和測試集

# Hold out 20% of the samples as a test set; random_state is fixed so the
# split (and therefore the reported accuracy) is reproducible.
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=666)

3、手動尋找

# def temp():
#     knn_clf = KNeighborsClassifier(3)
#     knn_clf.fit(X_train, y_train)
#     y_predict = knn_clf.predict(X_test)
#     accuracy_score(y_test, y_predict)

# # 尋找最好的k
# best_score = 0.0
# best_k = -1
# for k in range(1, 11):
#     knn_clf = KNeighborsClassifier(k)
#     knn_clf.fit(X_train, y_train)
#     y_predict = knn_clf.predict(X_test)
#     score = accuracy_score(y_test, y_predict)
#     if score > best_score:
#         best_k = k
#         best_score = score
# print("best_k:", best_k)
# print("best_score:", best_score)

# # 考慮距離?不考慮距離?
# best_method = ""
# best_score = 0.0
# best_k = -1
# for method in ["uniform", "distance"]:
#     for k in range(1, 11):
#         knn_clf = KNeighborsClassifier(n_neighbors=k, weights=method)
#         knn_clf.fit(X_train, y_train)
#         y_predict = knn_clf.predict(X_test)
#         score = accuracy_score(y_test, y_predict)
#         if score > best_score:
#             best_k = k
#             best_score = score
#             best_method = method
# print("best_k:", best_k)
# print("best_score:", best_score)
# print("best_method:", best_method)

# # 探索明可夫斯基距離相應的p
# # 尋找最好的超引數 Grid Search

3、超引數配置

# Hyperparameter search space for the KNN classifier:
#   * uniform weighting  -> tune only n_neighbors (p is irrelevant here)
#   * distance weighting -> tune n_neighbors and the Minkowski exponent p
param_grid = [
    {
        "weights": ["uniform"],
        "n_neighbors": list(range(1, 11)),
    },
    {
        "weights": ["distance"],
        "n_neighbors": list(range(1, 11)),
        "p": list(range(1, 6)),
    },
]

4、例項化分類器

# Bare KNN classifier; its hyperparameters will be chosen by the grid search.
knn_clf = KNeighborsClassifier()

5、為分類器和超引數搭建模型

from sklearn.model_selection import GridSearchCV
# Exhaustive search over every combination in param_grid, scored by
# cross-validation; n_jobs=-1 uses all CPU cores, verbose=2 logs each fit.
grid_search = GridSearchCV(knn_clf, param_grid, n_jobs=-1, verbose=2)

6、例項化模型(多種引數配置的分類器)fit訓練集

# Internally re-splits the training set (cross-validation) to evaluate each
# parameter combination and keep the best one.
# Trying every combination with cross-validation is why this step is slow.
grid_search.fit(X_train, y_train)

7、最終拿到最佳引數配置分類器 best_estimator_

# The classifier refit with the best parameter combination found by the search.
knn_clf = grid_search.best_estimator_

8、使用最佳分類器對測試集預測

# Predict the held-out test set with the tuned classifier.
y_predict = knn_clf.predict(X_test)

9、列印準確率

# Final generalization accuracy on the untouched 20% test split.
print(accuracy_score(y_test, y_predict))