09.尋找最好的超引數
阿新 • • 發佈:2020-11-26
import numpy as np import matplotlib import matplotlib.pyplot as plt from sklearn import datasets from sklearn.model_selection import train_test_split from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score
1、獲取資料
# Load the digits dataset and separate the feature matrix from the labels.
digits = datasets.load_digits()
X = digits.data
y = digits.target
2、分割資料,得到訓練集和測試集
# Hold out 20% of the samples as the test set; the fixed random_state keeps
# the split reproducible between runs.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    random_state=666,
)
3、手動尋找(以下程式碼已全部註解掉,僅供參考)
# NOTE(review): everything below is commented-out reference code for the
# manual hyperparameter search. The original line breaks and indentation were
# lost, so the layout here is a best-effort reconstruction -- confirm the
# nesting before re-enabling any of it.
#
# def temp():
#     knn_clf = KNeighborsClassifier(3)
#     knn_clf.fit(X_train, y_train)
#     y_predict = knn_clf.predict(X_test)
#     accuracy_score(y_test, y_predict)
#
# # Find the best k
# best_score = 0.0
# best_k = -1
# for k in range(1,11):
#     knn_clf = KNeighborsClassifier(k)
#     knn_clf.fit(X_train, y_train)
#     y_predict = knn_clf.predict(X_test)
#     score= accuracy_score(y_test, y_predict)
#     if score > best_score:
#         best_k = k
#         best_score = score
# print("best_k:", best_k)
# print("best_score:", best_score)
#
# # Weight neighbors by distance, or not?
# best_method = ""
# best_score = 0.0
# best_k = -1
# for method in ["uniform", "distance"]:
#     for k in range(1,11):
#         knn_clf = KNeighborsClassifier(n_neighbors=k, weights=method)
#         knn_clf.fit(X_train, y_train)
#         y_predict = knn_clf.predict(X_test)
#         score= accuracy_score(y_test, y_predict)
#         if score > best_score:
#             best_k = k
#             best_score = score
#             best_method = method
# print("best_k:", best_k)
# print("best_score:", best_score)
# print("best_method:", best_method)
#
# # Explore the Minkowski distance parameter p
#
# # Find the best hyperparameters: Grid Search
3、超引數配置
# Search space for the grid search: each dict is one sub-grid.
#   * uniform weighting  -> only n_neighbors (1..10) is varied
#   * distance weighting -> n_neighbors (1..10) combined with p (1..5)
param_grid = [
    {
        "weights": ["uniform"],
        "n_neighbors": list(range(1, 11)),
    },
    {
        "weights": ["distance"],
        "n_neighbors": list(range(1, 11)),
        "p": list(range(1, 6)),
    },
]
4、例項化分類器
# Base estimator built with default arguments; it is handed to the grid
# search below together with param_grid.
knn_clf = KNeighborsClassifier()
5、為分類器和超引數搭建模型
from sklearn.model_selection import GridSearchCV

# Wrap the base classifier and the parameter grid in a grid search.
# n_jobs=-1 and verbose=2 are passed through to GridSearchCV unchanged.
grid_search = GridSearchCV(
    knn_clf,
    param_grid,
    n_jobs=-1,
    verbose=2,
)
6、例項化模型(多種引數配置的分類器)fit訓練集
# In essence, the training set is further split into training and test
# (validation) parts in order to find the best parameter configuration.
# Every parameter combination has to be cross-validated, so this step is
# very time-consuming.
grid_search.fit(X_train, y_train)
7、最終拿到最佳引數配置分類器 best_estimator_
# Rebind knn_clf to the best estimator found by the grid search; the
# default-constructed classifier is no longer referenced by this name.
knn_clf = grid_search.best_estimator_
8、使用最佳分類器對測試集預測
# Predict labels for the held-out test set with the selected classifier.
y_predict = knn_clf.predict(X_test)
9、列印準確率
# Report the accuracy of the predictions against the true test labels.
print(accuracy_score(y_true=y_test, y_pred=y_predict))