資料結構和演算法——kd樹
一、K-近鄰演算法
K-近鄰演算法是一種典型的無參監督學習演算法,對於一個監督學習任務來說,其mm個訓練樣本為:
{(X(1),y(1)),(X(2),y(2)),⋯,(X(m),y(m))}
left { left ( X^{left ( 1 right )},y^{left ( 1 right )} right ),left ( X^{left ( 2 right )},y^{left ( 2 right )} right ),cdots ,left ( X^{left ( m right )},y^{left ( m right )} right ) right }
在K-近鄰演算法中,無需利用訓練樣本學習出統一的模型,對於一個新的樣本,如XX,通過比較樣本XX與mm個訓練樣本的相似度,選擇出kk個最相似的樣本,並以這kk個樣本的標籤作為樣本XX的標籤。
在如上的描述中,樣本XX需要分別與mm個訓練樣本計算相似度,通常,使用的相似度的計算方法為歐式距離,即對於樣本Xi={xi,1,xi,2,⋯,xi,n}X_i=left { x_{i,1},x_{i,2},cdots ,x_{i,n} right }和樣本Xj={xj,1,xj,2,⋯,xj,n}X_j=left { x_{j,1},x_{j,2},cdots ,x_{j,n} right },其兩者之間的相似度為:
S=∑t=1n(xi,t−xj,t)2−−−−−−−−−−−−−√
S=sqrt{sum_{t=1}^{n}left ( x_{i,t}-x_{j,t} right )^2}
對於K-近鄰演算法的具體過程,可以參見博文簡單易學的機器學習演算法——K-近鄰演算法。
在K-近鄰演算法的計算過程中,通過暴力的對每一對樣本計算其相似度是非常好費時間的,那麼是否存在一種方法,能夠加快計算的速度?kd樹便是其中的一種方法。
二、kd樹
kd樹是一種對kk維空間中的例項點進行儲存以便對其進行快速檢索的樹形資料結構,且kd樹是一種二叉樹,表示對kk維空間的一個劃分。
1、二叉排序樹
在資料結構中,二叉排序樹又稱二叉查詢樹或者二叉搜尋樹。其定義為:二叉排序樹,或者是一棵空樹,或者是具有下列性質的二叉樹:
- 若它的左子樹不空,則左子樹上所有結點的值均小於它的根結點的值;
- 若它的右子樹不空,則右子樹上所有結點的值均大於它的根結點的值;
- 它的左、右子樹也分別為二叉排序樹。
一個典型的二叉排序樹的例子如下圖所示:
在二叉排序樹中,若以中序遍歷,則得到的是按照值大小排序的結果,即1->3->4->6->7->8->10->13->14。
如果需要檢索7,則從根結點開始:
- 7<87<8->左子樹
- 7>37>3->右子樹
- 7>67>6->右子樹
- 7=77=7->查詢結束
但是,對於二叉排序樹的建立,若構建二叉排序樹的順序為基本有序時,如按照1->3->4->6->7->8->10->13->14構建二叉排序樹,會得到如下的結果:
這樣的話,檢索效率會下降,為了避免這樣的情況的出現,會對二叉樹設定一些條件,如平衡二叉樹。對於二叉排序樹的更多內容,可以參見資料結構和演算法——二叉排序樹。
2、kd樹的概念
kd樹與二叉排序樹的基本思想類似,與二叉排序樹不同的是,在kd樹中,每一個節點表示的是一個樣本,通過選擇樣本中的某一維特徵,將樣本劃分到不同的節點中,如對於樣本{(7,2),(5,4),(9,6)}left { left ( 7,2 right ),left ( 5,4 right ),left ( 9,6 right ) right }, 考慮資料的第一維,首先,根節點為{(7,2)}left { left ( 7,2 right )right },由於樣本{(5,4)}left { left ( 5,4 right )right }的第一維55小於77,因此,樣本{(5,4)}left { left ( 5,4 right )right }在根節點的左子樹上,同理,樣本{(9,6)}left { left ( 9,6 right )right }在根節點的右子樹上。通過第一維可以構建如下的二叉樹模型:
在kd樹的基本操作中,主要包括kd樹的建立和kd樹的檢索兩個部分。
3、kd樹的建立
構造kd樹相當於不斷地用垂直於座標軸的超平面將kk維空間切分成一系列的kk維超矩陣區域。選擇劃分節點的方法主要有兩種:
- 順序選擇,即按照資料的順序依次在kd樹中插入節點;
- 選擇待劃分維數的中位數為劃分的節點。在kd樹的構建過程中,為了防止出現只有左子樹或者只有右子樹的情況出現,通常對於每一個節點,選擇樣本中的中位數作為切分點。這樣構建出來的kd樹時平衡的。
在李航的《統計機器學習》P41中有提到:平衡的kd樹搜尋時的效率未必是最優的。
在構建kd樹的過程中,也可以根據插入資料的順序構建kd樹,以二維資料集為例,其資料的順序依次為:
{(3,6),(7,5),(3,1),(6,2),(9,1),(2,7)}
left { left ( 3,6 right ),left ( 7,5 right ),left ( 3,1 right ),left ( 6,2 right ),left ( 9,1 right ),left ( 2,7 right ) right }
對於如上的二維資料集,構建kd樹:
- 選擇一維最為切分的維度,如選擇第00維,第一個數為(3,6)left ( 3,6 right ),其第00維的值為33,以(3,6)left ( 3,6 right )作為kd樹的根結點,若第00維的值大於33為右子樹,否則插入到左子樹中;
- 對後續的節點依次判斷,如(7,5)left ( 7,5 right ),選擇第00維,其值為77,大於33,插入到根結點的右子樹中,設定其維數為除了第00維以外的任一維。。。
按照如上的過程,我們劃分出來的kd樹如下圖所示:
此時,將樣本按照特徵空間劃分如下圖所示:
由以上的計算過程可以看出對於樹中節點,需要有資料項,當前節點的比較維度,指向左子樹的指標和指向右子樹的指標,可以設定其結構如下:
#define MAX_LEN 1024
typedef struct KDtree{
double data[MAX_LEN]; // 資料
int dim; // 選擇的維度
struct KDtree *left; // 左子樹
struct KDtree *right; // 右子樹
}kdtree_node;
構造kd樹的函式宣告為:
int kdtree_insert(kdtree_node *&tree_node, double *data, int layer, int dim);
函式的具體實現如下:
// 遞迴構建kd樹,通過節點所在的層數控制選擇的維度
int kdtree_insert(kdtree_node * &tree_node, double *data, int layer, int dim){
// 空樹
if (NULL == tree_node){
// 申請空間
tree_node = (kdtree_node *)malloc(sizeof(kdtree_node));
if (NULL == tree_node) return 1;
//插入元素
for (int i = 0; i < dim; i ++){
(tree_node->data)[i] = data[i];
}
tree_node->dim = layer % (dim);
tree_node->left = NULL;
tree_node->right = NULL;
return 0;
}
// 插入左子樹
if (data[tree_node->dim] <= (tree_node->data)[tree_node->dim]){
return kdtree_insert(tree_node->left, data, ++layer, dim);
}
// 插入右子樹
return kdtree_insert(tree_node->right, data, ++layer, dim);
}
當構建好了kd樹後,需要對kd樹進行遍歷,在這裡,實現了兩種kd樹的遍歷方法:
- 先序遍歷
- 中序遍歷
對於先序遍歷,其函式的宣告為:
void kdtree_print(kdtree_node *tree, int dim);
函式的具體實現為:
void kdtree_print(kdtree_node *tree, int dim){
if (tree != NULL){
fprintf(stderr, "dim:%dn", tree->dim);
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", (tree->data)[i]);
}
fprintf(stderr, "n");
kdtree_print(tree->left, dim);
kdtree_print(tree->right, dim);
}
}
對於中序遍歷,其函式的宣告為:
void kdtree_print_in(kdtree_node *tree, int dim);
函式的具體實現為:
void kdtree_print_in(kdtree_node *tree, int dim){
if (tree != NULL){
kdtree_print_in(tree->left, dim);
fprintf(stderr, "dim:%dn", tree->dim);
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", (tree->data)[i]);
}
fprintf(stderr, "n");
kdtree_print_in(tree->right, dim);
}
}
4、kd樹的檢索
與二叉排序樹一樣,在kd樹中,將樣本劃分到不同的空間中,在查詢的過程中,由於查詢在某些情況下僅需查詢部分的空間,這為查詢的過程節省了對大部分資料點的搜尋的時間,對於kd樹的檢索,其具體過程為:
- 從根節點開始,將待檢索的樣本劃分到對應的區域中(在kd樹形結構中,從根節點開始查詢,直到葉子節點,將這樣的查詢序列儲存到棧中)
- 以棧頂元素與待檢索的樣本之間的距離作為最短距離min_distance
- 執行出棧操作:
- 向上回溯,查詢到父節點,若父節點與待檢索樣本之間的距離小於當前的最短距離min_distance,則替換當前的最短距離min_distance
- 以待檢索的樣本為圓心(二維,高維情況下是球心),以min_distance為半徑畫圓,若圓與父節點所在的平面相割,則需要將父節點的另一棵子樹進棧,重新執行以上的出棧操作
- 直到棧為空
以查詢(6,3)left ( 6,3 right )為例,首先,我們需要找到待查詢的樣本所在的搜尋空間,搜尋空間如下圖中的黑色區域所示:
其對應的進棧序列為:{(3,6),(7,5),(6,2)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 6,2 right ) right }。
此時,以到(6,2)left ( 6,2 right )之間的距離為最短距離,最短距離min_distance為1,對棧頂元素出棧,此時棧中的序列為:{(3,6),(7,5)}left { left ( 3,6 right ),left ( 7,5 right ) right }。以待檢索樣本(6,3)left ( 6,3 right )為圓心,1為半徑畫圓,圓與(6,2)left ( 6,2 right )所在平面相割,如下圖所示:
此時,需要檢索以(6,2)left ( 6,2 right )為根節點的另外一棵子樹,即需要將(9,1)left ( 9,1 right )進棧,此時,棧中的序列為:{(3,6),(7,5),(9,1)}left { left ( 3,6 right ),left ( 7,5 right ),left ( 9,1 right ) right }。
注意:若需要進棧的子樹中有很多節點,則根據需要比較的元素的大小,將直到葉節點的所有節點都進棧,這一點在很多地方都寫得不清楚。
按照上述的步驟,再執行出棧的操作,直到棧為空。
檢索過程的函式宣告為:
void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result);
函式的具體實現為:
void search_nearest(kdtree_node *tree, double *data_search, int dim, double *result){
// 一直找到葉子節點
fprintf(stderr, "nstart searching....n");
stack<kdtree_node *> st;
kdtree_node *p = tree;
while (p->left != NULL || p->right != NULL){
st.push(p);// 將p壓棧
if (data_search[p->dim] <= (p->data)[p->dim]){// 選擇左子樹
// 判斷左子樹是否為空
if (p->left == NULL) break;
p = p->left;
}else{ // 選擇右子樹
if (p->right == NULL) break;
p = p->right;
}
}
// 現在與棧中的資料進行對比
double min_distance = distance(data_search, p->data, dim);// 與根結點之間的距離
fprintf(stderr, "init: %lfn", min_distance);
copy2result(p->data, result, dim);
// 列印最優值
for (int i = 0; i < dim; i++){
fprintf(stderr, "%lft", result[i]);
}
fprintf(stderr, "n");
double d = 0;
while (st.size() > 0){
kdtree_node *q = st.top();// 找到棧頂元素
st.pop(); // 出棧
// 判斷與父節點之間的距離
d = distance(data_search, q->data, dim);
if (d <= min_distance){
min_distance = d;
copy2result(q->data, result, dim);
}
// 判斷與分隔面是否相交
double d_line = distance_except_dim(data_search, q->data, q->dim); // 到平面之間的距離
if (d_line < min_distance){ // 相交
// 如果本來在右子樹,現在查詢左子樹
// 如果本來在左子樹,現在查詢右子樹
if (data_search[q->dim] > (q->data)[q->dim]){
// 選擇左子樹
if (q->left != NULL) q = q->left;
else q = NULL;
}else{
// 選擇右子樹
if (q->right != NULL) q = q->right;
else q = NULL;
}
if (q != NULL){
while (q->left != NULL || q->right != NULL){
st.push(q);
if (data_search[q->dim] <= (q->data)[q->dim]){
if (q->left == NULL) break;
q = q->left;
}else{
if (q->right == NULL) break;
q = q->right;
}
}
if (q->left == NULL && q->right == NULL) st.push(q);
}
}
}
}
在函式的實現中,需要用到的函式為:
- 兩個樣本之間的距離
double distance(double *a, double *b, int dim){
double d = 0.0;
for (int i = 0; i < dim; i ++){
d += (a[i] - b[i]) * (a[i] - b[i]);
}
return d;
}
- 待檢索的樣本到平面之間的距離
double distance_except_dim(double *a, double *b, int except_dim){
double d = (a[except_dim] - b[except_dim]) * (a[except_dim] - b[except_dim]);
return d;
}
- 複製最優的結果
void copy2result(double *a, double *result, int dim){
for (int i = 0; i < dim; i ++){
result[i] = a[i];
}
}
三、測試
利用如上的測試集,我們構建kd樹,並在kd樹中查詢(6,3)left ( 6,3 right ),測試程式碼如下:
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include "kdtree.h"
// 解析特徵
int parse_feature(char *p, double *fea, int *dim){
// 解析特徵
char *q = p;
int i = 0;
while ((q = strchr(p, 't')) != NULL){
*q = 0;
fea[i] = atof(p);
//fprintf(stderr, "atof(p):%lfn", atof(p));
p = q + 1;
//r = r + 1;
i += 1;
}
// 解析最後一個
fea[i] = atof(p);
*dim = i + 1;
//fprintf(stderr, "atof(p):%lfn", atof(p));
//fprintf(stderr, "fea:%lft%lfn", fea[0], fea[1]);
}
int main(){
kdtree_node *tree_node = NULL;
// 從檔案中讀入資料
FILE *fp = fopen("data.txt", "r");
char feature[MAX_LEN];
double data[MAX_LEN];
int data_dim = 0; // 資料的維數
double data_search[2] = {6.0, 3.0};
while (fgets(feature, MAX_LEN, fp)){
fprintf(stderr, "%s", feature);
parse_feature(feature, data, &data_dim);
fprintf(stderr, "distance: %lfn", distance(data, data_search, data_dim));
// 插入到kd樹中
kdtree_insert(tree_node, data, 0, data_dim);
}
fclose(fp);
fprintf(stderr, "dim:%dn", data_dim);
fprintf(stderr, "insert_okn");
// test
kdtree_print(tree_node, data_dim);
printf("n");
kdtree_print_in(tree_node, data_dim);
double result[2];
search_nearest(tree_node, data_search, data_dim, result);
fprintf(stderr, "n the final result: ");
for (int i = 0; i < data_dim; i++){
fprintf(stderr, "%lft", result[i]);
}
fprintf(stderr, "n");
return 0;
}
以上的程式碼以上處至Github,其地址為:kd-tree。若有不對的地方,歡迎指正。