Sklearn實現最近鄰演算法(KNN)
阿新 • • 發佈:2021-01-13
技術標籤:機器學習# Sklearn筆記sklearn機器學習knn
>>> from sklearn.neighbors import NearestNeighbors
>>> import numpy as np
>>> X = np.array([[-1, -1], [-2, -1], [-3, -2], [1, 1], [2, 1], [3, 2]])
>>> nbrs = NearestNeighbors(n_neighbors=2, algorithm='ball_tree').fit(X)
>> > distances, indices = nbrs.kneighbors(X)
>>> indices
array([[0, 1],
[1, 0],
[2, 1],
[3, 4],
[4, 3],
[5, 4]]...)
>>> distances
array([[0. , 1. ],
[0. , 1. ],
[0. , 1.41421356],
[0. , 1. ],
[0. , 1. ],
[0. , 1.41421356]])Copy
因為查詢集匹配訓練集,每個點的最近鄰點是其自身,距離為0。
還可以有效地生成一個稀疏圖來標識相連點之間的連線情況:
>>> nbrs.kneighbors_graph(X).toarray()
array([[ 1., 1., 0., 0., 0., 0.],
[ 1., 1., 0., 0., 0., 0.],
[ 0., 1., 1., 0., 0., 0.],
[ 0., 0., 0., 1. , 1., 0.],
[ 0., 0., 0., 1., 1., 0.],
[ 0., 0., 0., 0., 1., 1.]])