1. 程式人生 > >Kd-tree原理與實現

Kd-tree原理與實現

資料應用當中,最近鄰查詢是非常重要的功能。不論是資訊檢索,推薦系統,還是資料庫查詢,最近鄰查詢(Nearst Neighbor Search)可謂無處不在。它要實現的是幫助我們找到資料中和查詢最接近的一個或多個數據條目(前者叫NN search, 後者也叫kNN),其實本質上是一樣的,我在這篇部落格中講的Kd-tree主要就是針對這種最近鄰搜尋問題。

1. 基本原理

其實,這種問題本來是很容易解決的,只要設計好了資料相似度的度量方法(有關相似度量的方法詳細可參考我之前的部落格:資料相似性的度量方法總結)計算所有資料與查詢的距離,比較大小即可。但是隨著資料量的增大以及資料維度的提高,這種方法就很難在現實中應用了,因為效率會非常低。解決此類問題的思路基本分為兩類:
(1)通過構建索引,快速排除與查詢相關度不大的資料;
(2)通過降維的方法,對資料條目先降維,再查詢;
前者主要是為了解決資料量過大的問題,比較常見的有我們熟知的二叉搜尋樹,Merkel tree,B-tree,quad-tree等;後者主要是為了解決維度過大的問題,比較常見的方法有我在上一篇部落格中講的LSH:

LSH(Locality Sensitive Hashing)原理與實現

而我們今天要說的Kd-tree就是一種對多維歐式空間分割,從而構建的索引,屬於上面的第一類。

Kd-tree全稱叫做:k dimension tree,這是一種對於多維歐式空間分割構造的的二叉樹,其性質非常類似於二叉搜尋樹。我們先回顧一下二叉搜尋樹,它是一種具有如下特徵的二叉樹:
(1)若它的左子樹不為空,則左子樹上所有結點的值均小於它的根結點的值;
(2)若它的右子樹不為空,則右子樹上所有結點的值均大於它的根結點的值;
(3)它的左、右子樹也分別為二叉搜尋樹;
這個概念是資料結構基礎的東西,應該非常熟悉了,不再贅述,下面給出一棵普通的二叉搜尋樹的圖:
這裡寫圖片描述

如果我們把二叉搜尋樹所對應的資料集看做一個一維空間(因為這個資料集的每一個數據條目都是由一個單一的數值構成的),那麼實際上二叉搜尋樹的分割依據就是數值的大小,這樣的劃分,幫助我們以平均O(lg(n))的時間複雜度搜尋資料。

自然而然,我們會祥這樣一個問題,能不能在多維歐式空間中,構建一棵類似原理的二叉搜尋樹?這也就是我們今天說的Kd-tree.

2. kd-tree的構建

先拋開搜尋演算法怎樣設計這件事不管,我們單純地關心怎樣對多維歐式空間劃分。一維空間簡單,因為每個資料條目只有一個數值,我們直接比較數值大小,就能對這些資料條目劃分,可是在多維空間就存在一個關鍵問題:每個資料條目由多個數值組成,我們怎麼比較?

Kd-tree的原理是這樣的:我們不比較全部的k維資料,而是選擇其中某一個維度比較,根據這個維度進行空間劃分。那接下來,我們需要做的是兩件事:

  • 判斷出在哪一個維度比較,也就是說,我們所要切割的面在哪一個維度上。當然這種切割需要遵循一個基本要求,那就是儘量通過這個維度的切割,使得資料集均分(為二);
  • 判斷以哪個資料條目分依據劃分。上面我們說,要使得資料集均分為二,那當然要選擇一個合適的資料項,充當這個劃分的“點”。

總結一下,就是要選擇一個數據項,以這個資料項的某個維度的值為標準,同一維度的值大於這個值的資料項,劃分為一部分,小於的劃分為另一部分。根據這種劃分來構建二叉樹,就如同二叉搜尋樹那樣。

現在,針對上面的兩件事,我們需要做如下兩個工作:
1. 確定劃分維度:這裡維度的確定需要注意的是儘量要使得這個維度上所有資料項數值的分佈儘可能地有大方差,也就是說,資料在這個維度上儘可能分散。這就好比是我們切東西,如果你切的是一根黃瓜,當讓橫著切要比豎著切更容易。所以我們應該先對所有維度的數值計算方差,選擇方差最大的那個維度;
2. 選擇充當切割標準的資料項:那麼只需要求得這個維度上所有數值的中位數即可;

至此,可以設計出kd-tree的構建演算法了:

  • 對於一個由n維資料構成的資料集,我們首先尋找方差最大的那個維度,設這個維度是d,然後找出在維度d上所有資料項的中位數m,按m劃分資料集,一分為二,記這兩個資料子集為Dl,Dr。建立樹節點,儲存這次劃分的情況(記錄劃分的維度d以及中位數m);
  • Dl,Dr重複進行以上的劃分,並且將新生成的樹節點設定為上一次劃分的左右孩子;
  • 遞迴地進行以上兩步,直到不能再劃分為止(所謂不能劃分是說當前節點中包含的資料項的數量小於了我們事先規定的閾值,不失一般性,我在此篇部落格中預設這個閾值是2,也就是說所有葉子節點包含的資料項不會多於2條),不能再劃分時,將對應的資料儲存至最後的節點中,這些最後的節點也就是葉子節點。

現在可以給出kd-tree的實現程式碼。當然,首先需要設計幾個函式,供演算法呼叫,限於篇幅,這裡只是給出功能說明:

類或函式 作用
class-KdTreeNode kd-tree節點,包含以下6個Attributes
Attribute1-data 樹節點屬性,代表這個節點的資料項,其實是一個列表,如果不是葉子節點,則為空
Attribute2-split 樹節點屬性,代表構建樹時,對這個節點進行分割所依據的資料維度
Attribute3-median 樹節點屬性,代表構建樹時,所有上面split維度上資料的中位數
Attribute4-left 樹節點屬性,代表左孩子
Attribute5-right 樹節點屬性,代表右孩子
Attribute6-parent 樹節點屬性,代表父親節點,作用是在後面的搜尋演算法中用
Attribute7-visited 樹節點屬性,代表此節點是否被演算法回溯遍歷,作用是在後面的搜尋演算法中用
func-getSplit 函式,得到所有維度中方差最大那個維度的序號
func-getMedian 函式,得到要分割的維度的中位數

按照上面這樣設計,就可以實現kd-tree的構建了。我們這裡使用numpy庫,假設現在已經將所有的資料項讀入為一個ndarray型的資料矩陣datamatrixdatamatrix的每一行代表了一個數據項。那麼構建樹演算法的實現程式碼可以如下所示:

import numpy as np

# 樹節點類和其相關方法如下
class KdTreeNode(object):

    def __init__(self, dataMatrix):

        self.data = dataMatrix

        self.left, self.right = None, None
        self.parent = None

        self.split = self.getSplit()
        self.median = self.getMedian()

        self.visited = False

    def getSplit(self):# 取方差最大的維度作為分割維度,程式碼略

    def getMedian(self):# 得到這個分割維度上所有數值的中位數,程式碼略

# 構建kd-tree的函式,helper為其輔助函式,起到遞迴的作用
def buildKdTree(dataMatrix):

    root = KdTreeNode(dataMatrix)

    # there is only one data item in dataMatrix
    if root.data.shape[0] <= 1:
        return root

    helper(root)
    return root


def helper(root):

    if root is None or len(root.data) <= 2:
        return

    # distribute data into left and right
    leftData, rightData = [], []

    # generate left and right child
    for row in list(root.data):
        if row[root.split] <= root.median:
            leftData.append(row)
        else:
            rightData.append(row)

    left = KdTreeNode(np.array(leftData))
    left.parent = root

    right = KdTreeNode(np.array(rightData))
    right.parent = root

    root.data = None
    root.left = left
    root.right = right

    helper(root.left)
    helper(root.right)

我在這裡,借用部落格Kd-Tree演算法原理和開源實現程式碼中的測試樣例:資料集合(2,3), (5,4), (9,6), (4,7), (8,1), (7,2),按照以上演算法原理設計的kd-tree以及劃分情況如以下兩張圖所示:我在這裡直接借用了上面這個連結中部落格的圖,這位博主的文章思路寫的非常清晰。

這裡寫圖片描述
圖中,非葉節點的二元組中,第一個元素表示分割維度(split值),第二個維度表示,取得的中位數(median值)

3. 搜尋演算法

構建好kd-tree後,就可以執行搜尋演算法了。其實,這也是資訊檢索最常見的模式,先構建索引,然後依照索引執行搜尋演算法。當然幾乎所有的搜尋演算法都與其索引是配套的,也就是說,即便是同樣的資料,索引不同,其搜尋演算法就不同,而各有各的技巧。這也是資訊檢索技術最大的魅力之一。

閒話少說,看搜尋演算法。基本思路可分為如下3步:

  1. 依照非葉節點中儲存的分割維度以及中位數資訊,自根節點始,從上向下搜尋,直到到達葉子。遍歷的原則當然是比較分割維度上,查詢值與中位數的大小,設查詢為Q,當前遍歷到的節點為u,則若Q[u.split] > u.median,繼續遍歷u的右子樹,反之,遍歷左子樹。
  2. 遍歷到葉子之後,計算葉子節點中與查詢Q距離最小的資料項與查詢的距離,記為minDis;其後執行“回溯”操作,回溯至當前節點的父節點,判斷以Q為球心,以minDis為半徑的超球面是否與這個父節點的另一個分支所代表的區域有交集(其實,這裡的區域就是一個超矩形,它包含了所有這個節點代表的資料項)。如果沒有,繼續向上一層回溯;如果有,則按照1步繼續執行,探底到葉子節點後,如果此時Q與這個葉子節點中的資料項有更小的距離,則更新minDis
  3. 持續進行以上兩步,直到回溯至根節點,且根節點的兩個分支都被“探測”過為止。

但是這個裡面有一個難點:如何判斷以查詢Q為球心,以當前的minDis為半徑的超球面與樹中,一個非葉節點所代表的超矩形是否相交?
一種簡單的方法是在構建樹的時候直接給每個節點賦值一個超矩形,這個超矩形以一個樹節點屬性的形式存在。一般情況下是給出超矩形的一個最大點和一個最小點。判斷的方法只需要看如下的兩個條件是否都成立即可:

  • Q[u.split] + minDis >= minPoint[u.split]
  • Q[u.split] - minDis >= maxPoint[u.split]

其中,u為查詢當前遍歷到的節點的父節點,minPoint與maxPoint為u所代表的超矩形的最大點和最小點(所謂最大最小點,那二維空間的矩形來說,就是他的右上角的點和左下角的點,分別擁有這個矩形範圍內各個維度上的最大值和最小值)

原因很簡單,因為以Q為球心,以當前這個矩形區域的一個點為球面上一點的一個超球面,一定是經過了當前這個葉子所代表的區域,但是同時它不可能完全覆蓋他的兄弟節點代表的區域。這個道理聽上去有點亂,看下面這個圖就能明白:


圖中,Q1,Q2,Q3是三個查詢點,線段AB是這個矩形空間的分割情況。可見,上面的結論書成立的,同時,我們還可以得到一個觀點:只要|Q[u.split] - u.median|<= minDis那麼就是與其兄弟節點所代表的區域相交。其實這個道理也可以通過數學上的推導得到,如果不能理解的話一試便知。

說道這裡,可以給出搜尋演算法的實現程式碼了:

import math

# 計算兩個多維向量的歐式距離
def dis(item, query):程式碼略

# 回溯,找尋需要處理的下一個節點,下一節點應滿足不曾被演算法回溯遍歷
def findNextNode(cur):程式碼略

# 判斷以查詢為球心,以此時的最小距離minDis為半徑的超球面是否與節點所代表的超矩形相交
def intersect(node, query, radius):程式碼略

# 找到節點的兄弟節點
def getBrother(node):程式碼略


def search(root, query, result, minDis):

    cur = root

    # the root is None
    if not cur:
        return result

    # find leaf
    elif not cur.visited:
        while cur.left and cur.right:
            if query[cur.split] >= cur.median:
                cur = cur.right
            else:
                cur = cur.left

        # update the min dis if it is necessary
        for item in list(cur.data):
            tempDis = dis(item, query)
            if abs(tempDis - minDis) < 1e-9:
                result.append(list(item))
            elif tempDis < minDis:
                minDis = tempDis
                result = [list(item)]


        # update the visited
        cur.visited = True

        # process the next node
        cur = findNextNode(cur)
        if intersect(cur, query, minDis):
            return search(cur, query, result, minDis)
        else:
            cur.visited = True
            nextNode = findNextNode(cur)
            return search(nextNode, query, result, minDis)
    else:
        return result

依照演算法的設計,我們以上面的kd-tree的圖為例,可以看看搜尋演算法遍歷的順序:

  1. 查詢點(8, 3)自根節點起,按照分割維度以及中位數向下遍歷,找到葉子節點(9, 6),此時算得的最小距離為10
  2. 回溯,找到下一個需要處理的節點,也就是(8,1), (7,2)這個點(此時以(8,3)為圓心,以10為半徑的圓與這個點所代表區域相交),資料項 (7,2)與查詢(8, 3)的距離更近,為2,更新最小距離為2
  3. 回溯,此時,非葉節點<2, 2>這個點所在的分支已經被訪問過了,找到下一個需要處理的節點,<2, 4>這個點。不過計算距離發現,這個點所代表的區域並不與此時的圓相交,放棄對這一分支的搜尋;
  4. 回溯至根節點,並且此時根節點的兩個分支都被考慮了,搜尋結束,返回最近鄰(7, 2),最短距離是2

以上就是全部kd-tree的原理以及對應搜尋演算法的實現。內容我大多參考了部落格:Kd-Tree演算法原理和開源實現程式碼
限於篇幅,本篇部落格並未給出全部的詳細程式碼,若要參考,請檢視我的github主頁:KD-tree

不足之處,還望指正。