1. 程式人生 > 其它 >資料結構和演算法——kd樹

資料結構和演算法——kd樹

一、K-近鄰演算法

K-近鄰演算法是一種典型的無參監督學習演算法,對於一個監督學習任務來說,其mm個訓練樣本為:

{(X(1),y(1)),(X(2),y(2)),⋯,(X(m),y(m))}

left { left ( X^{left ( 1 right )},y^{left ( 1 right )} right ),left ( X^{left ( 2 right )},y^{left ( 2 right )} right ),cdots ,left ( X^{left ( m right )},y^{left ( m right )} right ) right }

在K-近鄰演算法中,無需利用訓練樣本學習出統一的模型,對於一個新的樣本,如XX,通過比較樣本XX與mm個訓練樣本的相似度,選擇出kk個最相似的樣本,並以這kk個樣本的標籤作為樣本XX的標籤。

在如上的描述中,樣本XX需要分別與mm個訓練樣本計算相似度,通常,使用的相似度的計算方法為歐式距離,即對於樣本Xi={xi,1,xi,2,⋯,xi,n}X_i=left { x_{i,1},x_{i,2},cdots ,x_{i,n} right }和樣本Xj={xj,1,xj,2,⋯,xj,n}X_j=left { x_{j,1},x_{j,2},cdots ,x_{j,n} right },其兩者之間的相似度為:

S=∑t=1n(xi,t−xj,t)2−−−−−−−−−−−−−√

S=sqrt{sum_{t=1}^{n}left ( x_{i,t}-x_{j,t} right )^2}

對於K-近鄰演算法的具體過程,可以參見博文簡單易學的機器學習演算法——K-近鄰演算法

在K-近鄰演算法的計算過程中,通過暴力的對每一對樣本計算其相似度是非常好費時間的,那麼是否存在一種方法,能夠加快計算的速度?kd樹便是其中的一種方法。

二、kd樹

kd樹是一種對kk維空間中的例項點進行儲存以便對其進行快速檢索的樹形資料結構,且kd樹是一種二叉樹,表示對kk維空間的一個劃分。

1、二叉排序樹

在資料結構中,二叉排序樹又稱二叉查詢樹或者二叉搜尋樹。其定義為:二叉排序樹,或者是一棵空樹,或者是具有下列性質的二叉樹:

  • 若它的左子樹不空,則左子樹上所有結點的值均小於它的根結點的值;
  • 若它的右子樹不空,則右子樹上所有結點的值均大於它的根結點的值;
  • 它的左、右子樹也分別為二叉排序樹。

一個典型的二叉排序樹的例子如下圖所示:

在二叉排序樹中,若以中序遍歷,則得到的是按照值大小排序的結果,即1->3->4->6->7->8->10->13->14。

如果需要檢索7,則從根結點開始:

  • 7<87<8->左子樹
  • 7>37>3->右子樹
  • 7>67>6->右子樹
  • 7=77=7->查詢結束

但是,對於二叉排序樹的建立,若構建二叉排序樹的順序為基本有序時,如按照1->3->4->6->7->8->10->13->14構建二叉排序樹,會得到如下的結果:

這樣的話,檢索效率會下降,為了避免這樣的情況的出現,會對二叉樹設定一些條件,如平衡二叉樹。對於二叉排序樹的更多內容,可以參見資料結構和演算法——二叉排序樹

2、kd樹的概念

kd樹與二叉排序樹的基本思想類似,與二叉排序樹不同的是,在kd樹中,每一個節點表示的是一個樣本,通過選擇樣本中的某一維特徵,將樣本劃分到不同的節點中,如對於樣本{(7,2),(5,4),(9,6)}left { left ( 7,2 right ),left ( 5,4 right ),left ( 9,6 right ) right }, 考慮資料的第一維,首先,根節點為{(7,2)}left { left ( 7,2 right )right },由於樣本{(5,4)}left { left ( 5,4 right )right }的第一維55小於77,因此,樣本{(5,4)}left { left ( 5,4 right )right }在根節點的左子樹上,同理,樣本{(9,6)}left { left ( 9,6 right )right }在根節點的右子樹上。通過第一維可以構建如下的二叉樹模型:

在kd樹的基本操作中,主要包括kd樹的建立和kd樹的檢索兩個部分。

3、kd樹的建立

構造kd樹相當於不斷地用垂直於座標軸的超平面將kk維空間切分成一系列的kk維超矩陣區域。選擇劃分節點的方法主要有兩種:

  • 順序選擇,即按照資料的順序依次在kd樹中插入節點;
  • 選擇待劃分維數的中位數為劃分的節點。在kd樹的構建過程中,為了防止出現只有左子樹或者只有右子樹的情況出現,通常對於每一個節點,選擇樣本中的中位數作為切分點。這樣構建出來的kd樹時平衡的。

在李航的《統計機器學習》P41中有提到:平衡的kd樹搜尋時的效率未必是最優的。

在構建kd樹的過程中,也可以根據插入資料的順序構建kd樹,以二維資料集為例,其資料的順序依次為:

{(3,6),(7,5),(3,1),(6,2),(9,1),(2,7)}

left { left ( 3,6 right ),left ( 7,5 right ),left ( 3,1 right ),left ( 6,2 right ),left ( 9,1 right ),left ( 2,7 right ) right }

對於如上的二維資料集,構建kd樹:

  • 選擇一維最為切分的維度,如選擇第00維,第一個數為(3,6)left ( 3,6 right ),其第00維的值為33,以(3,6)left ( 3,6 right )作為kd樹的根結點,若第00維的值大於33為右子樹,否則插入到左子樹中;
  • 對後續的節點依次判斷,如(7,5)left ( 7,5 right ),選擇第00維,其值為77,大於33,插入到根結點的右子樹中,設定其維數為除了第00維以外的任一維。。。

按照如上的過程,我們劃分出來的kd樹如下圖所示:

此時,將樣本按照特徵空間劃分如下圖所示:

由以上的計算過程可以看出對於樹中節點,需要有資料項,當前節點的比較維度,指向左子樹的指標和指向右子樹的指標,可以設定其結構如下:

#define MAX_LEN 1024

typedef struct KDtree{
        double data[MAX_LEN]; // 資料
        int dim; // 選擇的維度
        struct KDtree *left; // 左子樹
        struct KDtree *right; // 右子樹
}kdtree_node;

構造kd樹的函式宣告為:

int kdtree_insert(kdtree_node *&tree_node, double *data, int layer, int dim);

函式的具體實現如下:

// 遞迴構建kd樹,通過節點所在的層數控制選擇的維度
int kdtree_insert(kdtree_node * &tree_node, double *data, int layer, int dim){
        // 空樹
        if (NULL == tree_node){
                // 申請空間
                tree_node = (kdtree_node *)malloc(sizeof(kdtree_node));
                if (NULL == tree_node) return 1;

                //插入元素
                for (int i = 0; i < dim; i ++){
                        (tree_node->data)[i] = data[i];
                }
                tree_node->dim = layer % (dim);
                tree_node->left = NULL;
                tree_node->right = NULL;

                return 0;
        }

        // 插入左子樹
        if (data[tree_node->dim] <= (tree_node->data)[tree_node->dim]){
                return kdtree_insert(tree_node->left, data, ++layer, dim);
        }

        // 插入右子樹
        return kdtree_insert(tree_node->right, data, ++layer, dim);
}

當構建好了kd樹後,需要對kd樹進行遍歷,在這裡,實現了兩種kd樹的遍歷方法:

  • 先序遍歷
  • 中序遍歷

對於先序遍歷,其函式的宣告為:

void kdtree_print(kdtree_node *tree, int dim);

函式的具體實現為:

void kdtree_print(kdtree_node *tree, int dim){
        if (tree != NULL){
                fprintf(stderr, "dim:%dn", tree->dim);
                for (int i = 0; i < dim; i++){
                        fprintf(stderr, "%lft", (tree->data)[i]);
                }
                fprintf(stderr, "n");
                kdtree_print(tree->left, dim);
                kdtree_print(tree->right, dim);
        }
}

對於中序遍歷,其函式的宣告為:

void kdtree_print_in(kdtree_node *tree, int dim);

函式的具體實現為:

void kdtree_print_in(kdtree_node *tree, int dim){
        if (tree != NULL){
                kdtree_print_in(tree->left, dim);
                fprintf(stderr, "dim:%dn", tree->dim);
                for (int i = 0; i < dim; i++){
                        fprintf(stderr, "%lft", (tree->data)[i]);
                }
                fprintf(stderr, "n");
                kdtree_print_in(tree->right, dim);
        }
}

4、kd樹的檢索

與二叉排序樹一樣,在kd樹中,將樣本劃分到不同的空間中,在查詢的過程中,由於查詢在某些情況下僅需查詢部分的空間,這為查詢的過程節省了對大部分資料點的搜尋的時間,對於kd樹的檢索,其具體過程為:

  • 從根節點開始,將待檢索的樣本劃分到對應的區域中(在kd樹形結構中,從根節點開始查詢,直到葉子節點,將這樣的查詢序列儲存到棧中)
  • 以棧頂元素與待檢索的樣本之間的距離作為最短距離min_distance
  • 執行出棧操作:
    • 向上回溯,查詢到父節點,若父節點與待檢索樣本之間的距離小於當前的最短距離min_distance,則替換當前的最短距離min_distance
    • 以待檢索的樣本為圓心(二維,高維情況下是球心),以min_distance為半徑畫圓,若圓與父節點所在的平面相割,則需要將父節點的另一棵子樹進棧,重新執行以上的出棧操作
  • 直到棧為空

以查詢(6,3)left ( 6,3 right )為例,首先,我們需要找到待查詢的樣本所在的搜尋空間,搜尋空間如下圖中的黑色區域所示:

其對應的進棧序列為:{(3,6),(7,5),(6,2)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 6,2 right ) right }。

此時,以到(6,2)left ( 6,2 right )之間的距離為最短距離,最短距離min_distance為1,對棧頂元素出棧,此時棧中的序列為:{(3,6),(7,5)}left { left ( 3,6 right ),left ( 7,5 right ) right }。以待檢索樣本(6,3)left ( 6,3 right )為圓心,1為半徑畫圓,圓與(6,2)left ( 6,2 right )所在平面相割,如下圖所示:

此時,需要檢索以(6,2)left ( 6,2 right )為根節點的另外一棵子樹,即需要將(9,1)left ( 9,1 right )進棧,此時,棧中的序列為:{(3,6),(7,5),(9,1)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 9,1 right ) right }。

注意:若需要進棧的子樹中有很多節點,則根據需要比較的元素的大小,將直到葉節點的所有節點都進棧,這一點在很多地方都寫得不清楚。

按照上述的步驟,再執行出棧的操作,直到棧為空。

檢索過程的函式宣告為:

void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result);

函式的具體實現為:

void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result){
    // 一直找到葉子節點
    fprintf(stderr, "nstart searching....n");
    stack<kdtree_node *> st;

    kdtree_node *p = tree;

    while (p->left != NULL || p->right != NULL){
        st.push(p);// 將p壓棧
        if (data_search[p->dim] <= (p->data)[p->dim]){// 選擇左子樹
            // 判斷左子樹是否為空
            if (p->left == NULL) break;
            p = p->left;
        }else{ // 選擇右子樹
            if (p->right == NULL) break;
            p = p->right;
        }
    }


    // 現在與棧中的資料進行對比
    double min_distance = distance(data_search, p->data, dim);// 與根結點之間的距離
    fprintf(stderr, "init: %lfn", min_distance);
    copy2result(p->data, result, dim);
    // 列印最優值
    for (int i = 0; i < dim; i++){
                fprintf(stderr, "%lft", result[i]);
        }
        fprintf(stderr, "n");

    double d = 0;
    while (st.size() > 0){
        kdtree_node *q = st.top();// 找到棧頂元素
        st.pop(); // 出棧

        // 判斷與父節點之間的距離
        d = distance(data_search, q->data, dim);

        if (d <= min_distance){
            min_distance = d;
            copy2result(q->data, result, dim);
        }

        // 判斷與分隔面是否相交
        double d_line = distance_except_dim(data_search, q->data, q->dim); // 到平面之間的距離
        if (d_line < min_distance){ // 相交
            // 如果本來在右子樹,現在查詢左子樹
            // 如果本來在左子樹,現在查詢右子樹
            if (data_search[q->dim] > (q->data)[q->dim]){
                // 選擇左子樹
                if (q->left != NULL) q = q->left;
                else q = NULL;
            }else{
                // 選擇右子樹
                if (q->right != NULL) q = q->right;
                else q = NULL;
            }
            if (q != NULL){
                while (q->left != NULL || q->right != NULL){
                    st.push(q);
                    if (data_search[q->dim] <= (q->data)[q->dim]){
                        if (q->left == NULL) break;
                        q = q->left;
                    }else{
                        if (q->right == NULL) break;
                        q = q->right;
                    }
                }
                if (q->left == NULL && q->right == NULL) st.push(q);
            }
        }

    }
}

在函式的實現中,需要用到的函式為:

  • 兩個樣本之間的距離
double distance(double *a, double *b, int dim){
    double d = 0.0;
    for (int i = 0; i < dim; i ++){
        d += (a[i] - b[i]) * (a[i] - b[i]);
    }
    return d;
}
  • 待檢索的樣本到平面之間的距離
double distance_except_dim(double *a, double *b, int except_dim){
    double d = (a[except_dim] - b[except_dim]) * (a[except_dim] - b[except_dim]);
    return d;
}
  • 複製最優的結果
void copy2result(double *a, double *result, int dim){
    for (int i = 0; i < dim; i ++){
        result[i] = a[i];
    }
}

三、測試

利用如上的測試集,我們構建kd樹,並在kd樹中查詢(6,3)left ( 6,3 right ),測試程式碼如下:

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "kdtree.h"

// 解析特徵
int parse_feature(char *p, double *fea, int *dim){
        // 解析特徵
        char *q = p;
        int i = 0;
        while ((q = strchr(p, 't')) != NULL){
                *q = 0;
                fea[i] = atof(p);
                //fprintf(stderr, "atof(p):%lfn", atof(p));
                p = q + 1;
                //r = r + 1;
                i += 1;
        }

        // 解析最後一個
        fea[i] = atof(p);

        *dim = i + 1;
        //fprintf(stderr, "atof(p):%lfn", atof(p));

        //fprintf(stderr, "fea:%lft%lfn", fea[0], fea[1]);
}

int main(){
        kdtree_node *tree_node = NULL;
        // 從檔案中讀入資料
        FILE *fp = fopen("data.txt", "r");
        char feature[MAX_LEN];
        double data[MAX_LEN];
        int data_dim = 0; // 資料的維數

        double data_search[2] = {6.0, 3.0};

        while (fgets(feature, MAX_LEN, fp)){
                fprintf(stderr, "%s", feature);
                parse_feature(feature, data, &data_dim);
                fprintf(stderr, "distance: %lfn", distance(data, data_search, data_dim));
                // 插入到kd樹中
                kdtree_insert(tree_node, data, 0, data_dim);
        }
        fclose(fp);
        fprintf(stderr, "dim:%dn", data_dim);
        fprintf(stderr, "insert_okn");
        // test

        kdtree_print(tree_node, data_dim);
        printf("n");
        kdtree_print_in(tree_node, data_dim);


        double result[2];

        search_nearest(tree_node, data_search, data_dim, result);

        fprintf(stderr, "n the final result: ");
        for (int i = 0; i < data_dim; i++){
                fprintf(stderr, "%lft", result[i]);
        }
        fprintf(stderr, "n");
        return 0;
}

以上的程式碼以上處至Github,其地址為:kd-tree。若有不對的地方,歡迎指正。

參考文獻