三維點雲處理系列---- 二叉樹實現K-NN Radius-NN Search
阿新 • 發佈：2022-05-07
二叉樹的搜尋方法
一種是遞迴,一種是迴圈判斷,本質區別並不大
# Exact-key lookup in a binary search tree, written recursively and with a
# loop. Both return the node whose key equals the query, or None; they do not
# return the nearest key when there is no exact match.


def search_recursively(root, key):
    """Return the node in the BST at ``root`` whose key equals ``key``.

    Descends left for smaller keys and right for larger ones. Returns None
    when no node matches, including when ``key`` cannot be ordered against
    the stored keys (e.g. NaN).
    """
    if root is None or root.key == key:
        return root
    if key < root.key:
        return search_recursively(root.left, key)
    if key > root.key:
        return search_recursively(root.right, key)
    # key is neither ==, < nor > root.key (e.g. NaN): no match possible
    return None


def search_iterative(root, key):
    """Loop-based equivalent of ``search_recursively`` with the same contract."""
    current_node = root
    while current_node is not None:
        if current_node.key == key:
            return current_node
        elif key < current_node.key:
            current_node = current_node.left
        elif key > current_node.key:
            current_node = current_node.right
        else:
            # key is not ordered against this node's key (e.g. NaN); without
            # this branch the loop would never advance and spin forever
            return None
    return current_node
二叉樹的優勢,減少搜尋的複雜度
實際執行結果
Search in 100 points takes only 7 comparisons #使用二叉樹僅僅比較7次
Complexity is around O(log2(n)), where n is the number of
database points, if the tree is balanced #假設二叉樹是平衡的,複雜度約為O(log2(n)),n為資料點的數量(log2(n)約等於平衡樹的深度)
Worst case O(n) #最壞情況(例如按已排序的順序插入,樹退化成一條鏈),100個點需比較100次
kNN Search: index - distance 24 - 0.00 85 - 1.00 42 - 1.00 12 - 2.00 86 - 2.00 In total 8 comparison operations. Radius NN Search: index - distance 24 - 0.00 85 - 1.00 42 - 1.00 12 - 2.00 86 - 2.00 In total 5 neighbors within 2.000000. There are 8 comparison operations.
二叉樹的三種遍歷方式
# Three depth-first traversals of a binary tree. Each prints every node of
# the tree, differing only in where a node is printed relative to its two
# subtrees. They share one generator that yields nodes in the chosen order.


def _walk(node, order):
    """Yield the nodes of the subtree at ``node`` in depth-first ``order``.

    ``order`` is "pre", "in" or "post" and fixes whether a node is emitted
    before, between or after its left and right subtrees.
    """
    if node is None:
        return
    if order == "pre":
        yield node
    yield from _walk(node.left, order)
    if order == "in":
        yield node
    yield from _walk(node.right, order)
    if order == "post":
        yield node


def inorder(root):
    """Print the tree in-order: left subtree, node, right subtree."""
    for node in _walk(root, "in"):
        print(node)


def preorder(root):
    """Print the tree pre-order: node, left subtree, right subtree."""
    for node in _walk(root, "pre"):
        print(node)


def postorder(root):
    """Print the tree post-order: left subtree, right subtree, node."""
    for node in _walk(root, "post"):
        print(node)
1NN搜尋過程
KNN search
Worst distance for KNN
具體思路:
1.先建立一個能容納需要的臨近點結果的list
2.將暫時的KNN result 依距離由小到大保持排序
3.最大worst_dist 的點在KNN result list的最後(隨時被替代)
4.根據worst_dist的不斷更新,動態修改KNN result裡的結果
Radius NN search
方法思路和KNN演算法差不多,區別在於
Worst distance is fixed.(Radius NN search預先設定檢測radius,在radius裡進行點的篩選)
KNN search VS Radius NN search
完整程式碼
bst.py
"""Binary search tree over scalar keys with exact, kNN and radius-NN search."""
import random
import math
import numpy as np
from result_set import KNNResultSet, RadiusNNResultSet


class Node:
    """A BST node: a comparable ``key`` plus an attached ``value``."""

    def __init__(self, key, value=-1):
        self.left = None
        self.right = None
        self.key = key
        # value carries auxiliary data, e.g. the point's original index
        self.value = value

    def __str__(self):
        return "key: %s, value: %s" % (str(self.key), str(self.value))


def insert(root, key, value=-1):
    """Insert ``key``/``value`` into the subtree at ``root``; return its root.

    Smaller keys go to the left subtree, larger keys to the right. If ``key``
    is already present nothing is inserted and the existing value is kept.
    """
    if root is None:
        root = Node(key, value)
    else:
        if key < root.key:
            root.left = insert(root.left, key, value)
        elif key > root.key:
            root.right = insert(root.right, key, value)
        else:
            # don't insert if key already exist in the tree
            pass
    return root


# Three depth-first traversals; each prints every node of the tree.
def inorder(root):
    """Print the tree in-order (Left, Root, Right)."""
    if root is not None:
        inorder(root.left)
        print(root)
        inorder(root.right)


def preorder(root):
    """Print the tree pre-order (Root, Left, Right)."""
    if root is not None:
        print(root)
        preorder(root.left)
        preorder(root.right)


def postorder(root):
    """Print the tree post-order (Left, Right, Root)."""
    if root is not None:
        postorder(root.left)
        postorder(root.right)
        print(root)


def knn_search(root: Node, result_set: KNNResultSet, key):
    """Collect the keys nearest to ``key`` from the subtree at ``root``.

    Every visited node's distance ``|node.key - key|`` and ``node.value`` are
    offered to ``result_set``. The subtree on the query's side is searched
    first; the opposite subtree is visited only when ``|root.key - key|`` is
    below the result set's current worst distance, because every key in that
    subtree is strictly farther from the query than ``root.key``.

    Returns True when the search may stop early (the result set's worst
    distance is 0), otherwise False.
    """
    if root is None:
        return False

    # compare the root itself
    result_set.add_point(math.fabs(root.key - key), root.value)
    if result_set.worstDist() == 0:
        return True

    if root.key >= key:
        # iterate left branch first
        if knn_search(root.left, result_set, key):
            return True
        elif math.fabs(root.key - key) < result_set.worstDist():
            return knn_search(root.right, result_set, key)
        return False
    else:
        # iterate right branch first
        if knn_search(root.right, result_set, key):
            return True
        elif math.fabs(root.key - key) < result_set.worstDist():
            return knn_search(root.left, result_set, key)
        return False


def radius_search(root: Node, result_set: RadiusNNResultSet, key):
    """Collect every key within the result set's radius of ``key``.

    Same traversal and pruning as ``knn_search``, but the pruning bound
    (``result_set.worstDist()``) is the fixed radius. No branch here ever
    signals an early stop, so the function always returns False.
    """
    if root is None:
        return False

    # compare the root itself
    result_set.add_point(math.fabs(root.key - key), root.value)

    if root.key >= key:
        # iterate left branch first
        if radius_search(root.left, result_set, key):
            return True
        elif math.fabs(root.key - key) < result_set.worstDist():
            return radius_search(root.right, result_set, key)
        return False
    else:
        # iterate right branch first
        if radius_search(root.right, result_set, key):
            return True
        elif math.fabs(root.key - key) < result_set.worstDist():
            return radius_search(root.left, result_set, key)
        return False


def search_recursively(root, key):
    """Exact-key lookup (not nearest-neighbour): return the matching node or None."""
    if root is None or root.key == key:
        return root
    if key < root.key:
        return search_recursively(root.left, key)
    if key > root.key:
        return search_recursively(root.right, key)
    # key not ordered against root.key (e.g. NaN): no match possible
    return None


def search_iterative(root, key):
    """Loop-based exact-key lookup; same contract as ``search_recursively``."""
    current_node = root
    while current_node is not None:
        if current_node.key == key:
            return current_node
        elif key < current_node.key:
            current_node = current_node.left
        elif key > current_node.key:
            current_node = current_node.right
        else:
            # key not ordered against this node's key (e.g. NaN); without this
            # branch the loop would never advance and spin forever
            return None
    return current_node


def main():
    # Data generation
    db_size = 100
    k = 5  # number of nearest neighbours to retrieve
    radius = 2.0

    # random permutation of 0..db_size-1, so every key is unique
    data = np.random.permutation(db_size).tolist()

    root = None
    for i, point in enumerate(data):
        root = insert(root, point, i)

    query_key = 6
    result_set = KNNResultSet(capacity=k)
    knn_search(root, result_set, query_key)
    print('kNN Search:')
    print('index - distance')
    print(result_set)

    result_set = RadiusNNResultSet(radius=radius)
    radius_search(root, result_set, query_key)
    print('Radius NN Search:')
    print('index - distance')
    print(result_set)

    # print("inorder")
    # inorder(root)
    # print("preorder")
    # preorder(root)
    # print("postorder")
    # postorder(root)

    # node = search_recursively(root, 2)
    # print(node)
    #
    # node = search_iterative(root, 2)
    # print(node)


if __name__ == '__main__':
    main()
result_set.py(KNN 與 Radius NN search 所用的結果集合類別)
"""Result containers for kNN and radius-NN searches."""
import copy


class DistIndex:
    """A (distance, index) pair that orders by distance."""

    def __init__(self, distance, index):
        self.distance = distance
        self.index = index

    def __lt__(self, other):
        return self.distance < other.distance


class KNNResultSet:
    """Fixed-capacity container of the ``capacity`` closest (distance, index) pairs.

    ``dist_index_list`` is kept sorted by ascending distance. Its last slot
    holds the current worst accepted distance, returned by ``worstDist()`` so
    a search can prune branches that cannot improve the result. Until the set
    is full the worst distance is +inf.

    Raises:
        ValueError: if ``capacity`` is less than 1.
    """

    def __init__(self, capacity):
        if capacity < 1:
            raise ValueError("capacity must be at least 1, got %r" % (capacity,))
        self.capacity = capacity
        self.count = 0
        # +inf rather than a large finite sentinel, so no real distance is
        # rejected while the set is still filling up
        self.worst_dist = float('inf')
        self.dist_index_list = [DistIndex(self.worst_dist, 0) for _ in range(capacity)]
        # number of add_point calls, i.e. candidate distance comparisons
        self.comparison_counter = 0

    def size(self):
        return self.count

    def full(self):
        return self.count == self.capacity

    def worstDist(self):
        return self.worst_dist

    def add_point(self, dist, index):
        """Offer a candidate; keep it if it is among the best ``capacity`` so far."""
        self.comparison_counter += 1
        if dist > self.worst_dist:
            return

        if self.count < self.capacity:
            self.count += 1

        # insertion-sort step: shift worse entries one slot toward the end;
        # when the set is already full, the previous worst entry is dropped
        i = self.count - 1
        while i > 0 and self.dist_index_list[i - 1].distance > dist:
            self.dist_index_list[i] = self.dist_index_list[i - 1]
            i -= 1
        # a fresh object, so no two slots ever alias the same DistIndex
        self.dist_index_list[i] = DistIndex(dist, index)
        self.worst_dist = self.dist_index_list[self.capacity - 1].distance

    def __str__(self):
        # only the filled slots; unused placeholders are not real neighbours
        lines = ['%d - %.2f' % (dist_index.index, dist_index.distance)
                 for dist_index in self.dist_index_list[:self.count]]
        lines.append('In total %d comparison operations.' % self.comparison_counter)
        return '\n'.join(lines)


class RadiusNNResultSet:
    """Unbounded container of every (distance, index) pair within ``radius``."""

    def __init__(self, radius):
        self.radius = radius
        self.count = 0
        self.worst_dist = radius
        self.dist_index_list = []
        # number of add_point calls, i.e. candidate distance comparisons
        self.comparison_counter = 0

    def size(self):
        return self.count

    def worstDist(self):
        # for a radius query the pruning bound is fixed at the radius
        return self.radius

    def add_point(self, dist, index):
        """Record the candidate if ``dist`` does not exceed the radius."""
        self.comparison_counter += 1
        if dist > self.radius:
            return
        self.count += 1
        self.dist_index_list.append(DistIndex(dist, index))

    def __str__(self):
        # format a sorted copy; printing should not reorder stored results
        lines = ['%d - %.2f' % (dist_index.index, dist_index.distance)
                 for dist_index in sorted(self.dist_index_list)]
        lines.append('In total %d neighbors within %f.\nThere are %d comparison operations.'
                     % (self.count, self.radius, self.comparison_counter))
        return '\n'.join(lines)