python kd樹 搜索
阿新 • • 發佈:2018-02-09
blog arc 節點 inf fda dex num blank sum
kd樹就是一種對k維空間中的實例點進行存儲以便對其進行快速檢索的樹形數據結構,可以運用在k近鄰法中,實現快速k近鄰搜索。構造kd樹相當於不斷地用垂直於坐標軸的超平面將k維空間切分,依次選擇坐標軸對空間進行切分,選擇訓練實例點在選定坐標軸上的中位數為切分點。具體kd樹的原理可以參考kd樹的原理。
代碼是參考《統計學習方法》k近鄰 kd樹的python實現得到
首先創建一個類,用於表示樹的節點,包括:該節點的值,該節點的切分軸,左子樹,右子樹
class decisionnode: def __init__(self,value=None,col=None,rb=None,lb=None): self.value=value self.col=col self.rb=rb self.lb=lb
切分點為坐標軸上的中值,下面代碼求得一個序列的中值
def median(x): n=len(x) x=list(x) x_order=sorted(x) return x_order[n//2],x.index(x_order[n//2])
然後就可以構造一顆kd樹,左子樹小於切分點,右子樹大於切分點
def buildtree(x,j=0): rb=[] lb=[] m,n=x.shape ifm==0: return None edge,row=median(x[:,j].copy()) for i in range(m): if x[i][j]>edge: rb.append(i) if x[i][j]<edge: lb.append(i) rb_x=x[rb,:] lb_x=x[lb,:] rightBranch=buildtree(rb_x,(j+1)%n) leftBranch=buildtree(lb_x,(j+1)%n)return decisionnode(x[row,:],j,rightBranch,leftBranch)
接下來是樹的搜索過程,可以用下圖表示樹的搜索過程,具體過程可以參考kd樹的原理。
代碼如下:
#搜索樹:nearestPoint,nearestValue均為全局變量 def traveltree(node,point): global nearestPoint,nearestValue if node==None: return print(node.value) print(‘---‘) col=node.col if point[col]>node.value[col]: traveltree(node.rb,point) if point[col]<node.value[col]: traveltree(node.lb,point) dis=dist(node.value,point) print(dis) if dis<nearestValue: nearestPoint=node nearestValue=dis #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue)) if node.rb!=None or node.lb!=None: if abs(point[node.col] - node.value[node.col]) < nearestValue: if point[node.col]<node.value[node.col]: traveltree(node.rb,point) if point[node.col]>node.value[node.col]: traveltree(node.lb,point) def searchtree(tree,aim): global nearestPoint,nearestValue #nearestPoint=None nearestValue=float(‘inf‘) traveltree(tree,aim) return nearestPoint def dist(x1, x2): #歐式距離的計算 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5
完整代碼在此處取
1 import numpy as np 2 from numpy import array 3 class decisionnode: 4 def __init__(self,value=None,col=None,rb=None,lb=None): 5 self.value=value 6 self.col=col 7 self.rb=rb 8 self.lb=lb 9 10 #讀取數據並將數據轉換為矩陣形式 11 def readdata(filename): 12 data=open(filename).readlines() 13 x=[] 14 for line in data: 15 line=line.strip().split(‘\t‘) 16 x_i=[] 17 for num in line: 18 num=float(num) 19 x_i.append(num) 20 x.append(x_i) 21 x=array(x) 22 return x 23 24 #求序列的中值 25 def median(x): 26 n=len(x) 27 x=list(x) 28 x_order=sorted(x) 29 return x_order[n//2],x.index(x_order[n//2]) 30 31 #以j列的中值劃分數據,左小右大,j=節點深度%列數 32 def buildtree(x,j=0): 33 rb=[] 34 lb=[] 35 m,n=x.shape 36 if m==0: return None 37 edge,row=median(x[:,j].copy()) 38 for i in range(m): 39 if x[i][j]>edge: 40 rb.append(i) 41 if x[i][j]<edge: 42 lb.append(i) 43 rb_x=x[rb,:] 44 lb_x=x[lb,:] 45 rightBranch=buildtree(rb_x,(j+1)%n) 46 leftBranch=buildtree(lb_x,(j+1)%n) 47 return decisionnode(x[row,:],j,rightBranch,leftBranch) 48 49 #搜索樹:nearestPoint,nearestValue均為全局變量 50 def traveltree(node,point): 51 global nearestPoint,nearestValue 52 if node==None: return 53 print(node.value) 54 print(‘---‘) 55 col=node.col 56 if point[col]>node.value[col]: 57 traveltree(node.rb,point) 58 if point[col]<node.value[col]: 59 traveltree(node.lb,point) 60 dis=dist(node.value,point) 61 print(dis) 62 if dis<nearestValue: 63 nearestPoint=node 64 nearestValue=dis 65 #print(‘nearestPoint,nearestValue‘ % (nearestPoint,nearestValue)) 66 if node.rb!=None or node.lb!=None: 67 if abs(point[node.col] - node.value[node.col]) < nearestValue: 68 if point[node.col]<node.value[node.col]: 69 traveltree(node.rb,point) 70 if point[node.col]>node.value[node.col]: 71 traveltree(node.lb,point) 72 73 def searchtree(tree,aim): 74 global nearestPoint,nearestValue 75 #nearestPoint=None 76 nearestValue=float(‘inf‘) 77 traveltree(tree,aim) 78 return nearestPoint 79 80 81 def dist(x1, x2): #歐式距離的計算 82 return ((np.array(x1) - np.array(x2)) ** 2).sum() ** 0.5View Code
python kd樹 搜索