KD樹的python實踐
簡單的KNN演算法在為每個資料點預測類別時都需要遍歷整個訓練資料集來求解距離,這樣的做法在訓練資料集特別大的時候並不高效,一種改進的方法就是使用kd樹來儲存訓練資料集,這樣可以使KNN分類器更高效。
KD樹的主要思想跟二叉樹類似,我們先來回憶一下二叉樹的結構,二叉樹中每個節點可以看成是一個數,當前節點總是比左子樹中每個節點大,比右子樹中每個節點小。而KD樹中每個節點是一個向量(也可能是多個向量),和二叉樹總是按照數的大小劃分不同的是,KD樹每層需要選定向量中的某一維,然後根據這一維按左小右大的方式劃分資料。在構建KD樹時,關鍵需要解決2個問題:(1)選擇向量的哪一維進行劃分(2)如何劃分資料。第一個問題簡單的解決方法可以是選擇隨機選擇某一維或按順序選擇,但是更好的方法應該是在資料比較分散的那一維進行劃分(分散的程度可以根據方差來衡量)。好的劃分方法可以使構建的樹比較平衡,可以每次選擇中位數來進行劃分,這樣問題2也得到了解決。下面是建立KD樹的Python程式碼:
def build_tree(data, dim, depth):
"""
建立KD樹
Parameters
----------
data:numpy.array
需要建樹的資料集
dim:int
資料集特徵的維數
depth:int
當前樹的深度
Returns
-------
tree_node:tree_node namedtuple
樹的跟節點
"""
size = data.shape[0]
if size == 0 :
return None
# 確定本層劃分參照的特徵
split_dim = depth % dim
mid = size / 2
# 按照參照的特徵劃分資料集
r_indx = np.argpartition(data[:, split_dim], mid)
data = data[r_indx, :]
left = data[0: mid]
right = data[mid + 1: size]
mid_data = data[mid]
# 分別遞迴建立左右子樹
left = build_tree(left, dim, depth + 1 )
right = build_tree(right, dim, depth + 1)
# 返回樹的根節點
return Tree_Node(left=left,
right=right,
data=mid_data,
split_dim=split_dim)
對於一個新來的資料點x,我們需要查詢KD樹中距離它最近的節點。KD樹的查詢演算法還是和二叉樹查詢的演算法類似,但是因為KD樹每次是按照某一特定的維來劃分,所以當從跟節點沿著邊查詢到葉節點時候並不能保證當前的葉節點就離x最近,我們還需要回溯並在每個父節點上判斷另一個未查詢的子樹是否有可能存在離x更近的點(如何確定的方法我們可以思考二維的時候,以x為原點,當前最小的距離為半徑畫園,看是否與劃分的直線相交,相交則另一個子樹中可能存在更近的點),如果存在就進入子樹查詢。
當我們需要查詢K個距離x最近的節點時,我們只需要維護一個長度為K的優先佇列保持當前距離x最近的K個點。在回溯時,每次都使用第K短距離來判斷另一個子節點中是否存在更近的節點即可。下面是具體實現的Python程式碼:
def search_n(cur_node, data, queue, k):
"""
查詢K近鄰,最後queue中的k各值就是k近鄰
Parameters
----------
cur_node:tree_node namedtuple
當前樹的跟節點
data:numpy.array
資料
queue:Queue.PriorityQueue
記錄當前k個近鄰,距離大的先輸出
k:int
查詢的近鄰個數
"""
# 當前節點為空,直接返回上層節點
if cur_node is None:
return None
if type(data) is not np.array:
data = np.asarray(data)
cur_data = cur_node.data
# 得到左右子節點
left = cur_node.left
right = cur_node.right
# 計算當前節點與資料點的距離
distance = np.sum((data - cur_data) ** 2) ** .5
cur_split_dim = cur_node.split_dim
flag = False # 標記在回溯時是否需要進入另一個子樹查詢
# 根據參照的特徵來判斷是先進入左子樹還是右子樹
if data[cur_split_dim] > cur_data[cur_split_dim]:
tmp = right
right = left
left = tmp
# 進入子樹查詢
search_n(left, data, queue, k)
# 下面是回溯過程
# 當佇列中沒有k個近鄰時,直接將當前節點入隊,並進入另一個子樹開始查詢
if len(queue) < k:
neg_distance = -1 * distance
heapq.heappush(queue, (neg_distance, cur_node))
flag = True
else:
# 得到當前距離資料點第K遠的節點
top_neg_distance, top_node = heapq.heappop(queue)
# 如果當前節點與資料點的距離更小,則更新佇列(當前節點入隊,原第k遠的節點出隊)
if - 1 * top_neg_distance > distance:
top_neg_distance, top_node = -1 * distance, cur_node
heapq.heappush(queue, (top_neg_distance, top_node))
# 判斷另一個子樹內是否可能存在跟資料點的距離比當前第K遠的距離更小的節點
top_neg_distance, top_node = heapq.heappop(queue)
if abs(data[cur_split_dim] - cur_data[cur_split_dim]) < -1 * top_neg_distance:
flag = True
heapq.heappush(queue, (top_neg_distance, top_node))
# 進入另一個子樹搜尋
if flag:
search_n(right, data, queue, k)
以上就是KD樹的Python實踐的全部內容,由於本人剛接觸python不久,可能實現上並不優雅,也可能在演算法理解上存在偏差,如果有任何的錯誤或不足,希望各位賜教。