鳶尾花分類(基於neighbors模的塊knn)
阿新 • • 發佈:2022-03-27
""" 近鄰分類演算法是在 neighbors 模組的 KNeighborsClassifier 類中實現的。 我們需要將這個類例項化為一個物件,然後才能使用這個模型。這時我們需要設定模型的引數。 KNeighborsClassifier 最重要的引數就是鄰居的數目,這裡我們設為 1 knn 物件對演算法進行了封裝,既包括用訓練資料構建模型的演算法,也包括對新資料點進行 預測的演算法。它還包括演算法從訓練資料中提取的資訊。對於 KNeighborsClassifier 來說, 裡面只儲存了訓練集。 想要基於訓練集來構建模型,需要呼叫 knn 物件的 fit 方法,輸入引數為 X_train 和 y_ train,二者都是 NumPy 陣列,前者包含訓練資料,後者包含相應的訓練標籤 """ from sklearn.neighbors import KNeighborsClassifier from sklearn.datasets import load_iris from sklearn.model_selection import train_test_split import pandas as pd import numpy as np import scipy as sp import IPython import sklearn iris_dataset = load_iris() X_train, X_test, y_train, y_test = train_test_split( iris_dataset['data'], iris_dataset['target'], random_state=0) print(f"X_test shape: {X_train.shape}") print(f"y_test shape: {y_train.shape}") print(f"X_test shape: {X_test.shape}") print(f"y_test shape: {y_test.shape}") knn = KNeighborsClassifier(n_neighbors=1) knn.fit(X_train, y_train) # fit 方法返回的是 knn 物件本身並做原處修改,因此我們得到了分類器的字串表示。 X_new = np.array([[5, 2.9, 1, 0.2]]) # 我們將這朵花的測量資料轉換為二維 NumPy 陣列的一行,這是因為 scikit-learn的輸入資料必須是二維陣列。 print(f"X_new.shape: {X_new.shape}") print(X_new) #我們呼叫 knn 物件的 predict 方法來進行預測: prediction = knn.predict(X_new) print(f"Prediction: {prediction}") print(f"Prediction target name: {iris_dataset['target_names'][prediction]}") """ 評估模型 這裡需要用到之前建立的測試集。這些資料沒有用於構建模型,但我們知道測試集中每朵鳶尾花的實際品種。 因此,我們可以對測試資料中的每朵鳶尾花進行預測,並將預測結果與標籤(已知的品種)進行對比。 我們可以通過計算精度(accuracy)來衡量模型的優劣,精度就是品種預測正確的花所佔的比例 """ y_pred = knn.predict(X_test) print(f"Test set predictions:\n {y_pred}") print("Test set sore {:.2f}".format(np.mean(y_pred == y_test)))