1. 程式人生 > >python kd樹 搜索

python kd樹 搜索

blog arc 節點 inf fda dex num blank sum

  kd樹就是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構,可以運用在k近鄰法中,實現快速k近鄰搜索。構造kd樹相當於不斷地用垂直於坐標軸的超平面將k維空間切分,依次選擇坐標軸對空間進行切分,選擇訓練實例點在選定坐標軸上的中位數為切分點。具體kd樹的原理可以參考kd樹的原理。

  代碼是參考《統計學習方法》k近鄰 kd樹的python實現得到

  首先創建一個類,用於表示樹的節點,包括:該節點的值,該節點的切分軸,左子樹,右子樹

class decisionnode:
    def __init__(self,value=None,col=None,rb=None,lb=None):
        self.value
=value self.col=col self.rb=rb self.lb=lb

  切分點為坐標軸上的中值,下面代碼求得一個序列的中值

def median(x):
    n=len(x)
    x=list(x)
    x_order=sorted(x)
    return x_order[n//2],x.index(x_order[n//2])

然後就可以構造一顆kd樹,左子樹小於切分點,右子樹大於切分點

def buildtree(x,j=0):
    rb=[]
    lb=[]
    m,n=x.shape
    if
m==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n)
return decisionnode(x[row,:],j,rightBranch,leftBranch)

  接下來是樹的搜索過程,可以用下圖表示樹的搜索過程,具體過程可以參考kd樹的原理。

  技術分享圖片

  代碼如下:

#搜索樹:nearestPoint,nearestValue均為全局變量
def traveltree(node,point):
    global nearestPoint,nearestValue
    if node==None: return 
    print(node.value)
    print(---)
    col=node.col
    if point[col]>node.value[col]:
        traveltree(node.rb,point)
    if point[col]<node.value[col]:
        traveltree(node.lb,point)
    dis=dist(node.value,point)
    print(dis)
    if dis<nearestValue:
        nearestPoint=node
        nearestValue=dis
        #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue))
    if node.rb!=None or node.lb!=None:
        if abs(point[node.col] - node.value[node.col]) < nearestValue:
            if point[node.col]<node.value[node.col]:
                traveltree(node.rb,point)
            if point[node.col]>node.value[node.col]:
                traveltree(node.lb,point)
        
def searchtree(tree,aim):
    global nearestPoint,nearestValue
    #nearestPoint=None
    nearestValue=float(inf)
    traveltree(tree,aim)
    return nearestPoint
        
    
def dist(x1, x2): #歐式距離的計算  
    return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5

完整代碼在此處取

技術分享圖片
 1 import numpy as np
 2 from numpy import array
 3 class decisionnode:
 4     def __init__(self,value=None,col=None,rb=None,lb=None):
 5         self.value=value
 6         self.col=col
 7         self.rb=rb
 8         self.lb=lb
 9         
10 #讀取數據並將數據轉換為矩陣形式        
11 def readdata(filename):    
12     data=open(filename).readlines()
13     x=[]
14     for line in data:
15         line=line.strip().split(\t)
16         x_i=[]
17         for num in line:
18             num=float(num)
19             x_i.append(num)
20         x.append(x_i)
21     x=array(x)
22     return x
23 
24 #求序列的中值    
25 def median(x):
26     n=len(x)
27     x=list(x)
28     x_order=sorted(x)
29     return x_order[n//2],x.index(x_order[n//2])
30 
31 #以j列的中值劃分數據,左小右大,j=節點深度%列數    
32 def buildtree(x,j=0):
33     rb=[]
34     lb=[]
35     m,n=x.shape
36     if m==0: return None
37     edge,row=median(x[:,j].copy())
38     for i in range(m):
39         if x[i][j]>edge: 
40             rb.append(i)
41         if x[i][j]<edge:
42             lb.append(i)
43     rb_x=x[rb,:]
44     lb_x=x[lb,:]
45     rightBranch=buildtree(rb_x,(j+1)%n)
46     leftBranch=buildtree(lb_x,(j+1)%n)
47     return decisionnode(x[row,:],j,rightBranch,leftBranch)
48 
49 #搜索樹:nearestPoint,nearestValue均為全局變量
50 def traveltree(node,point):
51     global nearestPoint,nearestValue
52     if node==None: return 
53     print(node.value)
54     print(---)
55     col=node.col
56     if point[col]>node.value[col]:
57         traveltree(node.rb,point)
58     if point[col]<node.value[col]:
59         traveltree(node.lb,point)
60     dis=dist(node.value,point)
61     print(dis)
62     if dis<nearestValue:
63         nearestPoint=node
64         nearestValue=dis
65         #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue))
66     if node.rb!=None or node.lb!=None:
67         if abs(point[node.col] - node.value[node.col]) < nearestValue:
68             if point[node.col]<node.value[node.col]:
69                 traveltree(node.rb,point)
70             if point[node.col]>node.value[node.col]:
71                 traveltree(node.lb,point)
72         
73 def searchtree(tree,aim):
74     global nearestPoint,nearestValue
75     #nearestPoint=None
76     nearestValue=float(inf)
77     traveltree(tree,aim)
78     return nearestPoint
79         
80     
81 def dist(x1, x2): #歐式距離的計算  
82     return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5  
View Code

python kd樹 搜索