kd樹和knn演算法的c語言實現
本文轉載自http://www.cnblogs.com/LCcnblogs/p/6169136.html
樓主正在學習機器學習演算法,歡迎學習交流。
#include<stdio.h>
#include<stdlib.h>#include<math.h>
#include<time.h>
typedef struct{//資料維度
double x;
double y;
}data_struct;
typedef struct kd_node{
data_struct split_data;//資料結點
int split;//分裂維
struct kd_node *left;//由位於該結點分割超面左子空間內所有資料點構成的kd-tree
struct kd_node *right;//由位於該結點分割超面右子空間內所有資料點構成的kd-tree
}kd_struct;
//用於排序
int cmp1( const void *a , const void *b )
{
return (*(data_struct *)a).x > (*(data_struct *)b).x ? 1:-1;
}
//用於排序
int cmp2( const void *a , const void *b )
{
return (*(data_struct *)a).y > (*(data_struct *)b).y ? 1:-1;
}
//計算分裂維和分裂結點
void choose_split(data_struct data_set[],int size,int dimension,int *split,data_struct *split_data)
{
int i;
data_struct *data_temp;
data_temp=(data_struct *)malloc(size*sizeof(data_struct));
for(i=0;i<size;i++) //data_temp臨時儲存資料集
data_temp[i]=data_set[i];
static int count=0;//設為靜態
*split=(count++)%dimension;//分裂維,在第split維上進行劃分
if((*split)==0) qsort(data_temp,size,sizeof(data_temp[0]),cmp1); //qsort為內建快速排序(待排陣列,待排陣列長度,陣列元素大小,比較大小的函式的指標)
else qsort(data_temp,size,sizeof(data_temp[0]),cmp2); //split=0代表1維,=1代表2維
*split_data=data_temp[(size-1)/2];//分裂結點排在中位
}
//判斷兩個資料點是否相等
int equal(data_struct a,data_struct b){
if(a.x==b.x && a.y==b.y) return 1;
else return 0;
}
//建立KD樹
kd_struct *build_kdtree(data_struct data_set[],int size,int dimension,kd_struct *T)
{
if(size==0) return NULL;//遞迴出口
else{
int sizeleft=0,sizeright=0;
int i,split;
data_struct split_data;
choose_split(data_set,size,dimension,&split,&split_data);
data_struct data_right[size];
data_struct data_left[size];
if (split==0){//x維
for(i=0;i<size;++i){
if(!equal(data_set[i],split_data) && data_set[i].x <= split_data.x){//比分裂結點小
data_left[sizeleft].x=data_set[i].x;
data_left[sizeleft].y=data_set[i].y;
sizeleft++;//位於分裂結點的左子空間的結點數
}
else if(!equal(data_set[i],split_data) && data_set[i].x > split_data.x){//比分裂結點大
data_right[sizeright].x=data_set[i].x;
data_right[sizeright].y=data_set[i].y;
sizeright++;//位於分裂結點的右子空間的結點數
}
}
}
else{//y維
for(i=0;i<size;++i){
if(!equal(data_set[i],split_data) && data_set[i].y <= split_data.y){
data_left[sizeleft].x=data_set[i].x;
data_left[sizeleft].y=data_set[i].y;
sizeleft++;
}
else if (!equal(data_set[i],split_data) && data_set[i].y > split_data.y){
data_right[sizeright].x = data_set[i].x;
data_right[sizeright].y = data_set[i].y;
sizeright++;
}
}
}
T=(kd_struct *)malloc(sizeof(kd_struct));
T->split_data.x=split_data.x;
T->split_data.y=split_data.y;
T->split=split;
T->left=build_kdtree(data_left,sizeleft,dimension,T->left);//左子空間
T->right=build_kdtree(data_right,sizeright,dimension,T->right);//右子空間
return T;//返回指標
}
}
//計算歐氏距離
double compute_distance(data_struct a,data_struct b){
double tmp=pow(a.x-b.x,2.0)+pow(a.y-b.y,2.0);
return sqrt(tmp);
}
//搜尋1近鄰
void search_nearest(kd_struct *T,int size,data_struct test,data_struct *nearest_point,double *distance)
{
int path_size;//搜尋路徑內的指標數目
kd_struct *search_path[size];//搜尋路徑儲存各結點的指標
kd_struct* psearch=T;
data_struct nearest;//最近鄰的結點
double dist;//查詢結點與最近鄰結點的距離
search_path[0]=psearch;//初始化搜尋路徑
path_size=1;
while(psearch->left!=NULL || psearch->right!=NULL){
if (psearch->split==0){
if(test.x <= psearch->split_data.x)//如果小於就進入左子樹
psearch=psearch->left;
else
psearch=psearch->right;
}
else{
if(test.y <= psearch->split_data.y)//如果小於就進入右子樹
psearch=psearch->left;
else
psearch=psearch->right;
}
search_path[path_size++]=psearch;//將經過的分裂結點儲存在搜尋路徑中
}
//取出search_path最後一個元素,即葉子結點賦給nearest
nearest.x=search_path[path_size-1]->split_data.x;
nearest.y=search_path[path_size-1]->split_data.y;
path_size--;//search_path的指標數減一
dist=compute_distance(nearest,test);//計算與該葉子結點的距離作為初始距離
//回溯搜尋路徑
kd_struct* pback;
while(path_size!=0){
pback=search_path[path_size-1];//取出search_path最後一個結點賦給pback
path_size--;//search_path的指標數減一
if(pback->left==NULL && pback->right==NULL){//如果pback為葉子結點
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data,test);
}
}
else{//如果pback為分裂結點
int s=pback->split;
if(s==0){//x維
if(fabs(pback->split_data.x-test.x)<dist){//若以查詢點為中心的圓(球或超球),半徑為dist的圓與分割超平面相交,那麼就要跳到另一邊的子空間去搜索
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data, test);
}
if(test.x<=pback->split_data.x)//若查詢點位於pback的左子空間,那麼就要跳到右子空間去搜索
psearch=pback->right;
else
psearch=pback->left;//若以查詢點位於pback的右子空間,那麼就要跳到左子空間去搜索
if(psearch!=NULL)
search_path[path_size++]=psearch;//psearch加入到search_path中
}
}
else {//y維
if(fabs(pback->split_data.y-test.y)<dist){//若以查詢點為中心的圓(球或超球),半徑為dist的圓與分割超平面相交,那麼就要跳到另一邊的子空間去搜索
if(dist>compute_distance(pback->split_data,test)){
nearest=pback->split_data;
dist=compute_distance(pback->split_data,test);
}
if(test.y<=pback->split_data.y)//若查詢點位於pback的左子空間,那麼就要跳到右子空間去搜索
psearch=pback->right;
else
psearch=pback->left;//若查詢點位於pback的的右子空間,那麼就要跳到左子空間去搜索
if(psearch!=NULL)
search_path[path_size++]=psearch;//psearch加入到search_path中
}
}
}
}
(*nearest_point).x=nearest.x;//最近鄰
(*nearest_point).y=nearest.y;
*distance=dist;//距離
}
int main()
{
int n=6;//資料個數
data_struct nearest_point;
double distance;
kd_struct *root=NULL;
data_struct data_set[6]={{2,3},{5,4},{9,6},{4,7},{8,1},{7,2}};//資料集
data_struct test={7.1,2.1};//查詢點
root=build_kdtree(data_set,n,2,root);
search_nearest(root,n,test,&nearest_point,&distance);
printf("nearest neighbor:(%.2f,%.2f)\ndistance:%.2f \n",nearest_point.x,nearest_point.y,distance);
return 0;
}
/* x 5,4
/ \
y 2,3 7.2
\ / \
x 4,7 8.1 9.6
*/