sklearn庫學習之K-NN演算法
k近鄰分類與k近鄰迴歸
# k-nearest-neighbours demos with scikit-learn and mglearn:
#   1. k-NN classification on the synthetic "forge" dataset
#   2. train/test accuracy versus n_neighbors on the breast-cancer dataset
#   3. k-NN regression on the synthetic "wave" dataset
import matplotlib.pyplot as plt
import mglearn
import numpy as np
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier, KNeighborsRegressor

# ---------------------------------------------------------------------------
# 1. Classification on the forge dataset
# ---------------------------------------------------------------------------
X, y = mglearn.datasets.make_forge()
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

clf = KNeighborsClassifier(n_neighbors=3)
clf.fit(X_train, y_train)
print("Test set predictions:{}".format(clf.predict(X_test)))
print("Test set accuracy:{:.2f}".format(clf.score(X_test, y_test)))

# One panel per neighbour count, each showing the decision boundary.
fig, axes = plt.subplots(1, 3, figsize=(10, 3))
for n_neighbors, ax in zip([1, 3, 9], axes):
    # These models are fitted on the full dataset (X, y), not the train split.
    clf = KNeighborsClassifier(n_neighbors=n_neighbors).fit(X, y)
    # Plot: visualise the decision boundary.
    mglearn.plots.plot_2d_separator(
        clf, X, fill=True, eps=0.5, ax=ax, alpha=0.4
    )
    # Mark the data points.
    mglearn.discrete_scatter(X[:, 0], X[:, 1], y, ax=ax)
    ax.set_title("{} neighbor(s)".format(n_neighbors))
    ax.set_xlabel("feature 0")
    ax.set_ylabel("feature 1")
    # NOTE(review): the original line layout was lost; the legend call is
    # assumed to belong to the loop body (one legend per panel) - confirm.
    ax.legend(loc=3)

# ---------------------------------------------------------------------------
# 2. Accuracy versus n_neighbors on the breast-cancer dataset
# ---------------------------------------------------------------------------
cancer = load_breast_cancer()
X_train, X_test, y_train, y_test = train_test_split(
    cancer.data, cancer.target, stratify=cancer.target, random_state=66
)

training_accuracy = []
test_accuracy = []
neighbors_settings = range(1, 11)
for n_neighbors in neighbors_settings:
    clf = KNeighborsClassifier(n_neighbors=n_neighbors)
    clf.fit(X_train, y_train)
    training_accuracy.append(clf.score(X_train, y_train))
    test_accuracy.append(clf.score(X_test, y_test))

fig, ax = plt.subplots(1, 1, figsize=(10, 6))
ax.plot(neighbors_settings, training_accuracy, label="training accuracy")
ax.plot(neighbors_settings, test_accuracy, label="test accuracy")
ax.set_xlabel("n_neighbors")
ax.set_ylabel("Accuracy")
ax.legend()

# ---------------------------------------------------------------------------
# 3. Regression on the wave dataset
# ---------------------------------------------------------------------------
X, y = mglearn.datasets.make_wave(n_samples=40)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

fig, axes = plt.subplots(1, 3, figsize=(15, 4))
# 1000 evenly spaced query points in [-3, 3], reshaped to a single-feature
# column so they can be passed to predict().
line = np.linspace(-3, 3, 1000).reshape(-1, 1)
for n_neighbors, ax in zip([1, 3, 9], axes):
    reg = KNeighborsRegressor(n_neighbors=n_neighbors)
    reg.fit(X_train, y_train)
    print("Test set predictions:{}".format(reg.predict(X_test)))
    # A regressor's score() is the coefficient of determination (R^2), not a
    # classification accuracy, so the label says so.
    print("Test set R^2:{:.2f}".format(reg.score(X_test, y_test)))

    ax.plot(line, reg.predict(line))
    ax.plot(X_train, y_train, '^', c=mglearn.cm2(0), markersize=8)
    ax.plot(X_test, y_test, '.', c=mglearn.cm2(1), markersize=8)
    ax.set_title(
        "{}neighbor(s)\n train score:{:.2f} test score:{:.2f}".format(
            n_neighbors,
            reg.score(X_train, y_train),
            reg.score(X_test, y_test),
        )
    )
    ax.set_xlabel('Feature')
    ax.set_ylabel('Target')
    # NOTE(review): assumed to be inside the loop, as above - confirm. The
    # label order matches the order of the three ax.plot calls.
    ax.legend(
        ['Model predictions', 'Training data/target', 'Test data/target'],
        loc='best',
    )

# Without this, the figures are never displayed when run as a plain script.
plt.show()
對於程式碼中函式用法的疑惑
- python中關於圖例legend在圖外的畫法簡析
  https://blog.csdn.net/yywan1314520/article/details/53740001/
- [python] pandas plot( )畫圖命令總結
  https://blog.csdn.net/u013084616/article/details/79064408
- Python之matplotlib基礎
  https://www.cnblogs.com/liutongqing/p/6985805.html
- tensorflow的reshape操作tf.reshape()
  https://blog.csdn.net/m0_37592397/article/details/78695318
- numpy.linspace使用詳解
  https://blog.csdn.net/you_are_my_dream/article/details/53493752
- fig,ax = plt.subplots()的理解
  https://www.jianshu.com/p/decf22446316
- train_test_split用法
  https://blog.csdn.net/mrxjh/article/details/78481578
- make_blobs聚類資料生成器
  https://blog.csdn.net/kevinelstri/article/details/52622960
- sklearn提供的自帶的資料集
- Python DeprecationWarning 型別錯誤
  https://blog.csdn.net/qq_38734403/article/details/79779713