scikit-learn 中KNN分類繪圖
阿新 • • 發佈:2018-11-29
- scikit-learn 中KNN分類繪圖
- 參考連結:
- KNN相關的類庫概述:
https://www.cnblogs.com/pinard/p/6065607.html - 下載的toy資料集:
https://blog.csdn.net/sa14023053/article/details/52086695 - plt.scatter各引數詳解:
https://blog.csdn.net/weixin_40713373/article/details/80024583
# -*- coding: utf-8 -*- """ Created on Sun Nov 25 15:55:09 2018 @author: muli """ import numpy as np import matplotlib.pyplot as plt from sklearn.datasets.samples_generator import make_classification from sklearn import neighbors # 繪製背景的邊界 from matplotlib.colors import ListedColormap # 生成隨機資料 # X為樣本特徵,Y為樣本類別輸出, 共1000個樣本,每個樣本2個特徵, # 輸出有3個類別,沒有冗餘特徵,每個類別一個簇 X, Y = make_classification(n_samples=1000, n_features=2, n_redundant=0, n_clusters_per_class=1, n_classes=3,random_state=1) # X 為樣本的特徵,此案例中,只定義為兩類 # marker='o':圓形 # c=Y:顏色,順序或顏色順序, `c`可以是一個二維陣列,其中的行是RGB或RGBA,但是,包括單個的情況行為所有點指定相同的顏色 # 可認為 c 顏色由聚類的簇 n_classes=n 自動決定 plt.scatter(X[:, 0], X[:, 1], marker='o', c=Y) plt.show() # KNeighborsClassifier 分類器 clf = neighbors.KNeighborsClassifier(n_neighbors = 15 , weights='distance') clf.fit(X, Y) print("------------------------------") # 顏色濃 cmap_light = ListedColormap(['#FFAAAA', '#AAFFAA', '#AAAAFF']) # 顏色淡 cmap_bold = ListedColormap(['#FF0000', '#00FF00', '#0000FF']) # 確認訓練集的邊界 # 由 X特徵的最值確定 確定 x_min, x_max = X[:, 0].min() - 1, X[:, 0].max() + 1 y_min, y_max = X[:, 1].min() - 1, X[:, 1].max() + 1 # 生成隨機資料來做測試集,然後作預測 # x_min--x_max,步長為 0.02----等差數列 # xx,yy分別是X的兩個特徵的其中一個 xx, yy = np.meshgrid(np.arange(x_min, x_max, 0.02), np.arange(y_min, y_max, 0.02)) # np.r_是按列連線兩個矩陣,就是把兩矩陣上下相加,要求列數相等。 # np.c_是按行連線兩個矩陣,就是把兩矩陣左右相加,要求行數相等。 # Z為測試集的資料 Z = clf.predict(np.c_[xx.ravel(), yy.ravel()]) # 畫出測試集資料 Z = Z.reshape(xx.shape) plt.figure() # plt.pcolormesh(xx, yy, y_predict, cmap=cmap_light) # 作用:畫出不同型別資料的色彩範圍--區域 # xx,yy:影象區域內的取樣點--組織成一個點 # y_predict:根據取樣點計算出的每個點所屬的類別 # camp:將相應的值對映到顏色 plt.pcolormesh(xx, yy, Z, cmap=cmap_light) # 也畫出所有的訓練集資料 plt.scatter(X[:, 0], X[:, 1], c=Y, cmap=cmap_bold) plt.xlim(xx.min(), xx.max()) plt.ylim(yy.min(), yy.max()) plt.title("3-Class classification (k = 15, weights = 'distance')" )
- 如圖所示: