2. KNN和KdTree算法實現
阿新 • • 發佈:2019-01-20
nload bug bre pen 操作性 nod 系列 特點 最近鄰
1. K近鄰算法(KNN)
2. KNN和KdTree算法實現
1. 前言
KNN一直是一個機器學習入門需要接觸的第一個算法,它有著簡單,易懂,可操作性強的一些特點。今天我久帶領大家先看看sklearn中KNN的使用,在帶領大家實現出自己的KNN算法。
2. KNN在sklearn中的使用
knn在sklearn中是放在sklearn.neighbors的包中的,我們今天主要介紹KNeighborsClassifier的分類器。
KNeighborsClassifier的主要參數是:
參數 | 意義 |
---|---|
n_neighbors | K值的選擇與樣本分布有關,一般選擇一個較小的K值,可以通過交叉驗證來選擇一個比較優的K值,默認值是5 |
weights | ‘uniform’是每個點權重一樣,‘distance’則權重和距離成反比例,即距離預測目標更近的近鄰具有更高的權重 |
algorithm | ‘brute’對應第一種蠻力實現,‘kd_tree’對應第二種KD樹實現,‘ball_tree’對應第三種的球樹實現, ‘auto’則會在上面三種算法中做權衡,選擇一個擬合最好的最優算法。 |
leaf_size | 這個值控制了使用KD樹或者球樹時, 停止建子樹的葉子節點數量的閾值。 |
metric | K近鄰法和限定半徑最近鄰法類可以使用的距離度量較多,一般來說默認的歐式距離(即p=2的閔可夫斯基距離)就可以滿足我們的需求。 |
p | p是使用距離度量參數 metric 附屬參數,只用於閔可夫斯基距離和帶權重閔可夫斯基距離中p值的選擇,p=1為曼哈頓距離, p=2為歐式距離。默認為2 |
我個人認為這些個參數,比較重要的應該屬n_neighbors、weights了,其他默認的也都沒太大問題。
3. KNN基礎版實現
直接看代碼如下,完整代碼GitHub:
def fit(self, X_train, y_train): self.X_train = X_train self.y_train = y_train def predict(self, X): # 取出n個點 knn_list = [] for i in range(self.n): dist = np.linalg.norm(X - self.X_train[i], ord=self.p) knn_list.append((dist, self.y_train[i])) for i in range(self.n, len(self.X_train)): max_index = knn_list.index(max(knn_list, key=lambda x: x[0])) dist = np.linalg.norm(X - self.X_train[i], ord=self.p) if knn_list[max_index][0] > dist: knn_list[max_index] = (dist, self.y_train[i]) # 統計 knn = [k[-1] for k in knn_list] return Counter(knn).most_common()[0][0]
我的接口設計都是按照sklearn的樣子設計的,fit方法其實主要用來接收參數了,沒有進行任何的處理。所有的操作都在predict中,著就會導致,我們對每個點預測的時候,時間消耗比較大。這個基礎版本大家看看就好,我想大家自己去寫,肯定也沒問題的。
4. KdTree版本實現
kd樹算法包括三步,第一步是建樹,第二部是搜索最近鄰,最後一步是預測。
4.1 構建kd樹
kd樹是一種對n維空間的實例點進行存儲,以便對其進行快速檢索的樹形結構。kd樹是二叉樹,構造kd樹相當於不斷的用垂直於坐標軸的超平面將n維空間進行劃分,構成一系列的n維超矩陣區域。
下面的流程圖更加清晰的描述了kd樹的構建過程:
kdtree樹的生成代碼:
# 建立kdtree
def create(self, dataSet, label, depth=0):
if len(dataSet) > 0:
m, n = np.shape(dataSet)
self.n = n
axis = depth % self.n
mid = int(m / 2)
dataSetcopy = sorted(dataSet, key=lambda x: x[axis])
node = Node(dataSetcopy[mid], label[mid], depth)
if depth == 0:
self.KdTree = node
node.lchild = self.create(dataSetcopy[:mid], label, depth+1)
node.rchild = self.create(dataSetcopy[mid+1:], label, depth+1)
return node
return None
4.2 kd樹搜索最近鄰和預測
當我們生成kd樹以後,就可以去預測測試集裏面的樣本目標點了。預測的過程如下:
- 對於一個目標點,我們首先在kd樹裏面找到包含目標點的葉子節點。以目標點為圓心,以目標點到葉子節點樣本實例的距離為半徑,得到一個超球體,最近鄰的點一定在這個超球體內部。
- 然後返回葉子節點的父節點,檢查另一個子節點包含的超矩形體是否和超球體相交,如果相交就到這個子節點尋找是否有更加近的近鄰,有的話就更新最近鄰,並且更新超球體。如果不相交那就簡單了,我們直接返回父節點的父節點,在另一個子樹繼續搜索最近鄰。
- 當回溯到根節點時,算法結束,此時保存的最近鄰節點就是最終的最近鄰。
kdtree樹的搜索代碼:
# 搜索kdtree的前count個近的點
def search(self, x, count = 1):
nearest = []
for i in range(count):
nearest.append([-1, None])
# 初始化n個點,nearest是按照距離遞減的方式
self.nearest = np.array(nearest)
def recurve(node):
if node is not None:
# 計算當前點的維度axis
axis = node.depth % self.n
# 計算測試點和當前點在axis維度上的差
daxis = x[axis] - node.data[axis]
# 如果小於進左子樹,大於進右子樹
if daxis < 0:
recurve(node.lchild)
else:
recurve(node.rchild)
# 計算預測點x到當前點的距離dist
dist = np.sqrt(np.sum(np.square(x - node.data)))
for i, d in enumerate(self.nearest):
# 如果有比現在最近的n個點更近的點,更新最近的點
if d[0] < 0 or dist < d[0]:
# 插入第i個位置的點
self.nearest = np.insert(self.nearest, i, [dist, node], axis=0)
# 刪除最後一個多出來的點
self.nearest = self.nearest[:-1]
break
# 統計距離為-1的個數n
n = list(self.nearest[:, 0]).count(-1)
'''
self.nearest[-n-1, 0]是當前nearest中已經有的最近點中,距離最大的點。
self.nearest[-n-1, 0] > abs(daxis)代表以x為圓心,self.nearest[-n-1, 0]為半徑的圓與axis
相交,說明在左右子樹裏面有比self.nearest[-n-1, 0]更近的點
'''
if self.nearest[-n-1, 0] > abs(daxis):
if daxis < 0:
recurve(node.rchild)
else:
recurve(node.lchild)
recurve(self.KdTree)
# nodeList是最近n個點的
nodeList = self.nearest[:, 1]
# knn是n個點的標簽
knn = [node.label for node in nodeList]
return self.nearest[:, 1], Counter(knn).most_common()[0][0]
這段代碼其實比較好的實現了上面搜索的思想。如果讀者對遞歸的過程想不太清楚,可以畫下圖,或者debug下我完整的代碼GitHub
5. 總結
本文實現了KNN的基礎版和KdTree版本,讀者可以去嘗試下ballTree的版本,理論上效率比KdTree還要好一些。
2. KNN和KdTree算法實現