最近鄰演算法的實現:k-d tree
阿新 • • 發佈:2019-01-08
#include <algorithm>
#include <iostream>
#include <limits>
#include <math.h>
#include <stack>
#include <vector>

/*
 * Nearest-neighbour search with a 2-d tree (k-d tree, k = 2).
 *
 * The program reads a training set of (x, y) tuples, builds a 2-d tree from
 * it and then answers nearest-neighbour queries against that tree.
 *
 * NOTE: there is deliberately no `using namespace std;` here.  The point type
 * below is called `data`, which is ambiguous with std::data (C++17) as soon
 * as namespace std is pulled into the global scope.
 */

/* A point in the plane. */
struct data
{
    double x = 0;
    double y = 0;
};

/* One node of the 2-d tree. */
struct Tnode
{
    struct data dom_elt;            // the point stored at this node
    int split = 0;                  // splitting dimension: 0 = x, 1 = y
    struct Tnode * left = nullptr;  // points whose split coordinate is <= dom_elt's
    struct Tnode * right = nullptr; // points whose split coordinate is >  dom_elt's
};

/* Strict ordering of two points by their x coordinate. */
bool cmp1(data a, data b){
    return a.x < b.x;
}

/* Strict ordering of two points by their y coordinate. */
bool cmp2(data a, data b){
    return a.y < b.y;
}

/* True when both coordinates of a and b compare equal. */
bool equal(data a, data b){
    return a.x == b.x && a.y == b.y;
}

/* Returns the coordinate of `point` along dimension `split` (0 = x, otherwise y). */
static double coordinate(const data &point, int split){
    return split == 0 ? point.x : point.y;
}

/*
 * Chooses the splitting dimension and the splitting point for exm_set.
 *
 * The variance of the set is computed on each dimension (DX = E[X^2] - (E[X])^2);
 * `split` is set to the dimension with the larger variance (y on a tie).  The
 * set is then sorted IN PLACE along that dimension and the element at index
 * size / 2 is copied into SplitChoice.
 *
 * exm_set     : the points to split; reordered by this call.
 * size        : number of points in exm_set.
 * split       : out - chosen dimension (0 = x, 1 = y).
 * SplitChoice : out - the median point on the chosen dimension.
 *
 * When the set is empty (null pointer or size <= 0) nothing is read or written.
 */
void ChooseSplit(data exm_set[], int size, int &split, data &SplitChoice){
    if (exm_set == nullptr || size <= 0) {
        return;
    }

    const double weight = 1.0 / static_cast<double>(size);
    double variance[2] = {0.0, 0.0};
    for (int dim = 0; dim < 2; ++dim) {
        double mean_of_squares = 0.0;
        double mean = 0.0;
        for (int i = 0; i < size; ++i) {
            const double value = coordinate(exm_set[i], dim);
            mean_of_squares += weight * value * value;
            mean += weight * value;
        }
        variance[dim] = mean_of_squares - mean * mean;
    }

    // set the split dimension
    split = variance[0] > variance[1] ? 0 : 1;
    std::sort(exm_set, exm_set + size, split == 0 ? cmp1 : cmp2);

    // set the split point value
    SplitChoice = exm_set[size / 2];
}

/*
 * Recursively builds a 2-d tree from the first `size` points of exm_set.
 *
 * exm_set : the training points; reordered in place by ChooseSplit.
 * size    : number of points in exm_set (any non-negative count).
 * T       : ignored; kept only so existing callers keep compiling.
 *
 * Returns the root of the new tree, or nullptr for an empty set.  Every point
 * equal to the chosen split point is represented by that single node, so the
 * tree holds each distinct point once.  Nodes are allocated with `new`;
 * release the tree with free_kdtree().
 */
Tnode* build_kdtree(data exm_set[], int size, Tnode* T){
    (void)T;
    if (exm_set == nullptr || size <= 0) {
        return nullptr;
    }

    // call function ChooseSplit to choose the split dimension and split point
    int split = 0;
    data dom_elt;
    ChooseSplit(exm_set, size, split, dom_elt);

    // Partition the remaining points around the split point.  Vectors are used
    // so the number of points is not limited by a fixed scratch-array size.
    const double pivot = coordinate(dom_elt, split);
    std::vector<data> exm_set_left;
    std::vector<data> exm_set_right;
    for (int i = 0; i < size; ++i) {
        const data &point = exm_set[i];
        if (equal(point, dom_elt)) {
            continue;
        }
        if (coordinate(point, split) <= pivot) {
            exm_set_left.push_back(point);
        }
        else {
            exm_set_right.push_back(point);
        }
    }

    Tnode *node = new Tnode;
    node->dom_elt = dom_elt;
    node->split = split;
    node->left = build_kdtree(exm_set_left.data(), static_cast<int>(exm_set_left.size()), nullptr);
    node->right = build_kdtree(exm_set_right.data(), static_cast<int>(exm_set_right.size()), nullptr);
    return node;
}

/* Releases every node of a tree created by build_kdtree; a null tree is a no-op. */
void free_kdtree(Tnode* T){
    if (T == nullptr) {
        return;
    }
    free_kdtree(T->left);
    free_kdtree(T->right);
    delete T;
}

/* Euclidean distance between a and b. */
double Distance(data a, data b){
    const double dx = a.x - b.x;
    const double dy = a.y - b.y;
    return sqrt(dx * dx + dy * dy);
}

/*
 * Finds the point stored in the tree that is closest to `target`.
 *
 * Kd           : root of the tree (may be nullptr).
 * target       : the query point.
 * nearestpoint : out - the closest stored point; left untouched for an empty tree.
 * distance     : out - Euclidean distance from target to nearestpoint, or
 *                +infinity when the tree is empty.
 *
 * The tree is walked depth-first with an explicit stack.  At every node the
 * child on the target's side of the splitting line is explored first; the
 * other child is kept together with a lower bound on the distance of anything
 * it can contain, and is skipped later if the circle around `target` with the
 * current best radius does not reach across the splitting line.
 */
void searchNearest(Tnode * Kd, data target, data &nearestpoint, double & distance){
    // 1. If the tree is empty, report an infinite distance and return.
    if (Kd == nullptr) {
        distance = std::numeric_limits<double>::infinity();
        return;
    }

    // A subtree waiting to be searched, with a lower bound on the distance
    // from the target to any point inside it.
    struct Pending
    {
        Tnode *node;
        double bound;
    };

    data nearest = Kd->dom_elt;
    double dist = Distance(nearest, target);

    std::stack<Pending> search_path;
    search_path.push(Pending{Kd, 0.0});

    // 2. Search downwards, pruning subtrees that cannot beat the best so far.
    while (!search_path.empty()) {
        const Pending current = search_path.top();
        search_path.pop();

        // Nothing in this subtree can be strictly closer than `nearest`.
        if (current.bound >= dist) {
            continue;
        }

        Tnode *node = current.node;
        const double candidate = Distance(node->dom_elt, target);
        if (candidate < dist) {
            nearest = node->dom_elt;
            dist = candidate;
        }

        const double node_coord = coordinate(node->dom_elt, node->split);
        const double target_coord = coordinate(target, node->split);

        // Same rule as in build_kdtree: "<=" belongs to the left subtree.
        const bool target_on_left = target_coord <= node_coord;
        Tnode *near_child = target_on_left ? node->left : node->right;
        Tnode *far_child = target_on_left ? node->right : node->left;

        // Everything on the far side is at least as far away as the splitting line.
        if (far_child != nullptr) {
            const double gap = fabs(node_coord - target_coord);
            search_path.push(Pending{far_child, std::max(current.bound, gap)});
        }
        // Pushed last so that the target's own side is searched first.
        if (near_child != nullptr) {
            search_path.push(Pending{near_child, current.bound});
        }
    }

    nearestpoint.x = nearest.x;
    nearestpoint.y = nearest.y;
    distance = dist;
}

// Define KDTREE_NO_MAIN to reuse the routines above from another program or a
// test harness without pulling in the interactive driver.
#ifndef KDTREE_NO_MAIN
/*
 * Interactive driver: reads training points until "-1 -1" (or end of input),
 * builds the tree, then answers nearest-neighbour queries until input ends.
 */
int main(){
    std::vector<data> exm_set; // grows with the input, no fixed maximum size
    double x, y;
    std::cout<<"Please input the training data in the form x y. One instance per line. Enter -1 -1 to stop."<<std::endl;
    while (std::cin>>x>>y){
        // The sentinel is the pair "-1 -1", as announced in the prompt.
        if (x == -1 && y == -1) {
            break;
        }
        data point;
        point.x = x;
        point.y = y;
        exm_set.push_back(point);
    }

    struct Tnode * root = build_kdtree(exm_set.data(), static_cast<int>(exm_set.size()), nullptr);
    if (root == nullptr) {
        std::cout<<"No training data was entered."<<std::endl;
        return 0;
    }

    data nearestpoint;
    double distance;
    data target;
    std::cout <<"Enter search point"<<std::endl;
    while (std::cin>>target.x>>target.y)
    {
        searchNearest(root, target, nearestpoint, distance);
        std::cout<<"The nearest distance is "<<distance<<",and the nearest point is "<<nearestpoint.x<<","<<nearestpoint.y<<std::endl;
        std::cout <<"Enter search point"<<std::endl;
    }

    free_kdtree(root);
    return 0;
}
#endif // KDTREE_NO_MAIN