
Finding the Nearest Neighbor with a kd-tree: A Python Implementation

Basic Concepts

A kd-tree is a data structure commonly used to implement the k-nearest-neighbor (KNN) algorithm. The basic idea is to use the instance points of a multidimensional space to recursively partition that space into regions, yielding a binary tree: instance points that lie on a splitting hyperplane become internal (non-leaf) nodes, while the instance points inside each resulting hyperrectangle end up in the leaf nodes.
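
As a rough picture of that structure (the nested dictionaries below are only an illustration; the implementation further down uses a small KdTree class with the same fields), a tree over the 1-D points 2, 5, 7, 9 split at the median 7 looks like this:

# hand-built sketch of the tree for the 1-D points 2, 5, 7, 9
tree = {
    'split': 0, 'nodes': [7],            # root: point lying on the splitting plane
    'l': {'split': 0, 'nodes': [5],      # points smaller than 7
          'l': {'split': 0, 'nodes': [2], 'l': None, 'r': None},
          'r': None},
    'r': {'split': 0, 'nodes': [9],      # points larger than 7
          'l': None, 'r': None},
}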

Hyperrectangle Partitioning

Given a dataset datalist whose elements Xi each consist of several feature values, we first collect Xi[0] for every instance and take its median, center. The instances with Xi[0] == center are stored in the root node, the left subtree is built recursively from the instances with Xi[0] < center, and the right subtree from those with Xi[0] > center. At the second level the splitting feature becomes Xi[1]; in general the splitting feature changes with the depth of the tree and is (d - 1) % k, where k is the number of feature dimensions and d is the depth (counted from 1) of the node being split.
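
As an illustration, one level of this rule applied to a small hand-made 2-D dataset looks as follows (the data and variable names are only for this sketch and are not part of the implementation below):

data = [(2, 3), (5, 4), (9, 6), (4, 7), (8, 1), (7, 2)]
d, k = 1, 2                                                  # depth of the root, number of features
split = (d - 1) % k                                          # -> 0: split on the first feature
center = sorted(x[split] for x in data)[len(data) // 2]      # median, here 7
root_points = [x for x in data if x[split] == center]        # stored in the root: [(7, 2)]
left_data = [x for x in data if x[split] < center]           # recurse on the left subtree
right_data = [x for x in data if x[split] > center]          # recurse on the right subtree
print(root_points, left_data, right_data)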

Python Implementation

import numpy as np
import matplotlib.pyplot as plot
import math
# kd-tree node class


class KdTree(object):
    """docstring for KdTree."""

    def __str__(self):
        return '{ nodes:' + str(self.nodes) + ', left:' + str(self.l) + ', right:' + str(self.r) + '}'

    def __init__(self):
        self.split = None    # index of the feature this node splits on
        self.l = None        # left child
        self.r = None        # right child
        self.f = None        # parent (father) node
        self.nodes = []      # instance points lying on the splitting plane


def createKdTree(split, datalist, k):
    if datalist is None or len(datalist) == 0:
        return None
    split = split % k
    node = KdTree()
    # median of the current splitting feature
    center = np.sort([a[split] for a in datalist])[len(datalist) // 2]
    leftData = [a for a in datalist if a[split] < center]
    rightData = [a for a in datalist if a[split] > center]
    node.split = split
    node.nodes = [a for a in datalist if a[split] == center]
    node.l = createKdTree(split + 1, leftData, k)
    node.r = createKdTree(split + 1, rightData, k)
    # set the parent pointers
    if node.l is not None:
        node.l.f = node
    if node.r is not None:
        node.r.f = node
    return node


# build the training data: three classes drawn from normal distributions
def createData():
    k = 2
    # first class
    X_01 = np.random.randn(20) + 2
    X_02 = np.random.randn(20) + 3
    Y_0 = np.full(20, 1)
    # second class
    X_11 = np.random.randn(20) + 2
    X_12 = 2 * np.random.randn(20) + 10
    Y_1 = np.full(20, 2)
    # third class
    X_21 = np.random.randn(20) + 8
    X_22 = 2 * np.random.randn(20) + 10
    Y_2 = np.full(20, 3)
    # merge the three classes
    X_1 = np.append(np.append(X_01, X_11), X_21)
    X_2 = np.append(np.append(X_02, X_12), X_22)
    Y = np.append(np.append(Y_0, Y_1), Y_2)
    return (list(zip(X_1, X_2, Y)), k)


# predict: for every instance in the test set X, return the class of its
# nearest training point
def predict(head, X):
    res = []
    for x in X:
        res.append(guessX(head, x))
    return res


# descend the tree towards the region that contains x
def guessX(node, x):
    if x[node.split] < node.nodes[0][node.split] and node.l is not None:
        return guessX(node.l, x)
    elif x[node.split] > node.nodes[0][node.split] and node.r is not None:
        return guessX(node.r, x)
    else:
        return neast(node, x)


# squared Euclidean distance over the features of X2 (a trailing class label
# in X1 is ignored)
def getDis(X1, X2):
    s = 0
    for i in range(len(X2)):
        s += (X1[i] - X2[i]) ** 2
    return s


# update the current best (squared distance, point) against the points in nodes
def get2Min(nodes, x, minDis, minX):
    for n in nodes:
        d = getDis(n, x)
        if d < minDis:
            minDis = d
            minX = n
    return (minDis, minX)


# nearest-neighbor search: scan the node where the descent stopped, then
# backtrack towards the root
def neast(node, x):
    minDis, minX = get2Min(node.nodes, x, float('inf'), None)
    if node.f is None:
        return minX[-1]
    return findY(node.f, x, minDis, minX, 0 if node.f.l == node else 1)


# backtrack one level; dire == 0 means we came up from the left child
def findY(node, x, minDis, minX, dire):
    # the points on this splitting plane and the sibling subtree can only hold
    # a closer point if the plane is nearer than the current best; minDis is a
    # squared distance, hence the comparison with its square root
    if math.fabs(node.nodes[0][node.split] - x[node.split]) <= math.sqrt(minDis):
        minDis, minX = get2Min(node.nodes, x, minDis, minX)
        minDis, minX = reachAll(node.r if dire == 0 else node.l, x, minDis, minX)
    if node.f is None:
        return minX[-1]
    return findY(node.f, x, minDis, minX, 0 if node.f.l == node else 1)


# exhaustively visit a subtree, updating the current best
def reachAll(node, x, minDis, minX):
    if node is None:
        return (minDis, minX)
    minDis, minX = get2Min(node.nodes, x, minDis, minX)
    minDis, minX = reachAll(node.l, x, minDis, minX)
    minDis, minX = reachAll(node.r, x, minDis, minX)
    return (minDis, minX)


if __name__ == '__main__':
    # training data: features + class label
    datalist, k = createData()
    head = createKdTree(0, datalist, k)
    # test data
    X_1test = np.random.randn(5) + 2
    X_2test = 2 * np.random.randn(5) + 10
    X = list(zip(X_1test, X_2test))
    # the predicted classes are returned as a list
    res = predict(head, X)
    print(res)
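
As a quick sanity check, the tree search can be compared with a brute-force scan over the same training data. The helper bruteForce below is not part of the original script; it only reuses getDis, createData, createKdTree and predict defined above:

# brute-force reference: label of the training instance closest to x
def bruteForce(datalist, x):
    return min(datalist, key=lambda a: getDis(a, x))[-1]


datalist, k = createData()
head = createKdTree(0, datalist, k)
X = list(zip(np.random.randn(5) + 2, 2 * np.random.randn(5) + 10))
print(predict(head, X))                        # labels found via the kd-tree
print([bruteForce(datalist, x) for x in X])    # labels found by brute force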