1. 程式人生 > >knn之構造kd樹和最近鄰求取c++實現

knn之構造kd樹和最近鄰求取c++實現

這份程式碼的測試樣例如下(第一行是點的個數,接著每行是一個點的座標,最後一行是要查詢的點):
6
7 2
2 3
5 4
4 7
9 6
8 1


8 2

這裡通過中位數來選取根節點。(這種方法其實在一定程度上有很大的問題:根節點的選取方法不同,整棵樹的結構也會不同。這裡由於資料的關係,建出來的樹不是完全二叉樹,所以對於某些特殊的樣例會出錯。比如 (10,10) 這個測試樣例,根本無法找到包含它的子節點(區域):節點 (9,6) 只有左子節點,查詢時往右子節點走就會碰到空指標,所以會導致出錯。)

#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
using namespace std;
// One node of the 2-D kd-tree.
struct node{
	std::pair<int,int> x;  // the stored point: first is the x1 coordinate, second is x2
	int dim;               // split axis tag: 1 means compare x.first, any other value means x.second
	node* left;            // child subtree, 0 when absent
	node* right;           // child subtree, 0 when absent
	node* father;          // parent link, 0 for a node that has no parent yet

	// Every argument is optional; the defaults give the point (0,0) with no links.
	node(std::pair<int,int> p=std::make_pair(0,0),int dim=0,node* left=0,node* right=0,node* father=0)
		:x(p),dim(dim),left(left),right(right),father(father)
	{
	}
};
// Ordering predicate for sort: true when a's x1 coordinate (x.first) is smaller than b's.
bool cmp1(node*a,node* b)
{
	const int lhs=a->x.first;
	const int rhs=b->x.first;
	return lhs<rhs;
}
// Ordering predicate for sort: true when a's x2 coordinate (x.second) is smaller than b's.
bool cmp2(node*a,node*b)
{
	const int lhs=a->x.second;
	const int rhs=b->x.second;
	return lhs<rhs;
}
// Global list of the input points; main fills it and hands it to buildtree.
vector<node*>vec;
// Recursively builds a kd-tree from the given nodes.
//
// cnt selects the axis for this level: 1 sorts by x.first (cmp1), any other value
// sorts by x.second (cmp2). Children are built with (cnt+1)%2, so the axis
// alternates from level to level.
//
// Returns 0 for an empty input and the single input node itself for a one-element
// input. Otherwise returns a newly allocated node that copies the median point,
// records cnt as its dim, and links to the two subtrees (whose father is set to it).
// A two-element input therefore produces a node with a left child only.
node* buildtree(vector<node*>temp,int cnt)
{
	if(temp.empty())
		return 0;
	if(temp.size()==1)
		return temp.front();

	sort(temp.begin(),temp.end(),cnt==1?cmp1:cmp2);

	// Everything before the median goes left, everything after it goes right.
	const int mid=temp.size()/2;
	const vector<node*> lower(temp.begin(),temp.begin()+mid);
	const vector<node*> upper(temp.begin()+mid+1,temp.end());

	const int next=(cnt+1)%2;
	node* left=buildtree(lower,next);
	node* right=buildtree(upper,next);

	node* fat=new node(temp[mid]->x,cnt,left,right,0);
	if(left!=0)
		left->father=fat;
	if(right!=0)
		right->father=fat;
	return fat;
}
// Prints the subtree in pre-order (node, then left subtree, then right subtree),
// one point per line as "x1 x2". A null root prints nothing.
void traverse(node*root)
{
	if(root==0)
		return;

	cout<<root->x.first<<" "<<root->x.second<<endl;
	traverse(root->left);
	traverse(root->right);
}
// Descends from root to the leaf whose region is the best match for key.
//
// At each inner node the coordinate selected by dim (1 -> x.first, otherwise
// x.second) is compared with the node's own: key <= split goes left, else right.
// buildtree can produce nodes with only one child (a two-point subset gives a
// left child only). When the preferred child is missing, the walk continues
// into the existing child instead of dereferencing a null pointer, so the
// result is always a leaf.
//
// Returns 0 when root is 0.
node*find_first_belong(node*key,node*root)
{
	node*temp=root;
	// Loop while temp is an inner node, i.e. it has at least one child.
	while(temp!=0&&(temp->left!=0||temp->right!=0))
	{
		const bool usefirst=(temp->dim==1);
		const int keyval=usefirst?key->x.first:key->x.second;
		const int split=usefirst?temp->x.first:temp->x.second;

		node*preferred=(keyval<=split)?temp->left:temp->right;
		node*fallback=(keyval<=split)?temp->right:temp->left;
		// At least one child exists here, so temp never becomes null.
		temp=(preferred!=0)?preferred:fallback;
	}
	return temp;
}
// Euclidean distance between the points stored in a and b.
double distance(node*a,node*b)
{
	const double d1=static_cast<double>(a->x.first)-static_cast<double>(b->x.first);
	const double d2=static_cast<double>(a->x.second)-static_cast<double>(b->x.second);
	return sqrt(pow(d1,2)+pow(d2,2));
}
// Brute-force nearest search over the whole subtree rooted at root.
//
// Author's note (translated): the textbook says to search recursively when
// the other region intersects the circle; this version simply visits every
// node of that subtree and compares distances instead.
//
// Visits the subtree breadth-first and returns the node whose distance to key
// is smallest among those strictly closer than mindis. If no node beats
// mindis, root itself is returned, so callers must re-check the distance.
// Returns 0 when root is 0 (an absent subtree), instead of dereferencing it.
node*query(node*key,node*root,double mindis)
{
	if(root==0)
		return 0;

	node*rec=root;
	double mind=mindis;
	queue<node*>q;
	q.push(root);
	while(!q.empty())
	{
		node*temp=q.front();
		q.pop();

		const double dis=distance(key,temp);
		if(dis<mind)
		{
			mind=dis;
			rec=temp;
		}
		if(temp->left!=0)
			q.push(temp->left);
		if(temp->right!=0)
			q.push(temp->right);
	}
	return rec;
}
// Walks from the leaf `belong` up to the tree root and returns the node closest to key.
//
// At every ancestor `fat`:
//   1. fat itself becomes the candidate if it is closer than the best so far;
//   2. if the circle around key with radius mindis crosses fat's splitting
//      line, the subtree on the other side is searched with query().
//
// The "other side" is the child of fat that the walk did NOT come from.
// Always searching fat->right would skip the left subtree whenever the walk
// arrives from the right. A missing sibling is skipped rather than passed on.
//
// Returns belong unchanged when key or belong is 0.
node*find_nearest(node*key,node*belong)
{
	if(key==0||belong==0)
		return belong;

	node*nearest=belong;
	double mindis=distance(key,belong);
	for(node*fat=belong->father;fat!=0;belong=fat,fat=fat->father)
	{
		const double fdis=distance(fat,key);
		if(fdis<mindis)
		{
			mindis=fdis;
			nearest=fat;
		}

		node*other=(fat->left==belong)?fat->right:fat->left;
		if(other==0)
			continue;

		// Distance from key to the splitting line x1=fat->x.first (dim 1) or x2=fat->x.second.
		const int gap=(fat->dim==1)?(fat->x.first-key->x.first):(fat->x.second-key->x.second);
		if(abs(gap)<mindis)
		{
			node*res=query(key,other,mindis);
			// query may hand back the subtree root even when nothing was closer, so re-check.
			if(res!=0&&distance(res,key)<mindis)
			{
				mindis=distance(res,key);
				nearest=res;
			}
		}
	}
	return nearest;
}
// Nearest-neighbour lookup: finds the leaf region for key, then backtracks
// towards the root to find the closest stored point.
// The result is now returned; previously control fell off the end of this
// non-void function, which is undefined behaviour.
node*search(node*key,node*root)
{
	node* belong=find_first_belong(key,root);
	return find_nearest(key,belong);
}
int main()
{
	int n;
	cin>>n;
	for(int i=0;i<n;i++)
	{
		int x,y;
		cin>>x>>y;
		node* temp=new node(make_pair(x,y));
		vec.push_back(temp);
	}
	node*root=buildtree(vec,1);
	//traverse(root);
	int x,y;
	cin>>x>>y;
	node *key=new node(make_pair(x,y));
    node*near=search(key,root);
	cout<<near->x.first<<" "<<near->x.second<<endl;
	
} 
以上程式碼經過測試,除了 (10,10) 這類特殊資料會出錯之外,其他基本正確,只是程式碼寫得很亂……

這裡還有一個很大的問題:我不知道一旦判定了圓和其他區域相交之後,該怎麼進行遞迴搜尋,所以這裡直接用了遍歷……

總算搞懂了什麼遞迴搜尋:

下面的是第二個版本:

#include<iostream>
#include<algorithm>
#include<cstring>
#include<vector>
#include<cmath>
#include<queue>
using namespace std;
// Node of the 2-D kd-tree used by the second version of the program.
struct node{
	std::pair<int,int> x;  // stored point (first = x1, second = x2)
	int dim;               // axis tag: 1 -> compare x.first, otherwise x.second
	node* left;            // 0 when there is no left subtree
	node* right;           // 0 when there is no right subtree
	node* father;          // 0 when the node has no parent

	// All parameters default, so `node()` is the point (0,0) with no links.
	node(std::pair<int,int> p=std::make_pair(0,0),int dim=0,node* left=0,node* right=0,node* father=0)
		:x(p),
		 dim(dim),
		 left(left),
		 right(right),
		 father(father)
	{
	}
};
// sort predicate: ascending order of the x1 coordinate (x.first).
bool cmp1(node*a,node* b)
{
	const int ax1=a->x.first;
	const int bx1=b->x.first;
	return ax1<bx1;
}
// sort predicate: ascending order of the x2 coordinate (x.second).
bool cmp2(node*a,node*b)
{
	const int ax2=a->x.second;
	const int bx2=b->x.second;
	return ax2<bx2;
}
// Global list of the input points; main fills it and hands it to buildtree.
vector<node*>vec;
// Builds a kd-tree over `temp` and returns its root.
//
// cnt is the axis tag for this level (1: order by x.first via cmp1, otherwise
// order by x.second via cmp2); the recursion passes (cnt+1)%2 to the children.
// An empty input gives 0 and a single node is returned as is. For larger inputs
// the median element's point is copied into a freshly allocated node whose dim
// is cnt; the elements before the median form the left subtree and those after
// it the right subtree, and both children get their father set to the new node.
node* buildtree(vector<node*>temp,int cnt)
{
	if(temp.empty())
		return 0;
	if(temp.size()==1)
		return temp.front();

	sort(temp.begin(),temp.end(),cnt==1?cmp1:cmp2);

	const int mid=temp.size()/2;
	const int childaxis=(cnt+1)%2;
	node* left=buildtree(vector<node*>(temp.begin(),temp.begin()+mid),childaxis);
	node* right=buildtree(vector<node*>(temp.begin()+mid+1,temp.end()),childaxis);

	node* fat=new node(temp[mid]->x,cnt,left,right,0);
	if(left!=0)
		left->father=fat;
	if(right!=0)
		right->father=fat;
	return fat;
}
// Pre-order dump of the subtree: prints "x1 x2" for the node, then recurses
// into the left child and then the right child. Does nothing for a null root.
void traverse(node*root)
{
	if(root==0)
		return;

	cout<<root->x.first<<" "<<root->x.second<<endl;
	traverse(root->left);
	traverse(root->right);
}
// Descends from root to the leaf whose region is the best match for key.
//
// At each inner node the coordinate selected by dim (1 -> x.first, otherwise
// x.second) is compared with the node's own: key <= split goes left, else right.
// buildtree can produce nodes with only one child (a two-point subset gives a
// left child only). When the preferred child is missing, the walk continues
// into the existing child instead of dereferencing a null pointer. That null
// dereference is what made queries such as (10,10) on the sample data crash.
//
// Returns 0 when root is 0; otherwise always returns a leaf.
node*find_first_belong(node*key,node*root)
{
	node*temp=root;
	// Loop while temp is an inner node, i.e. it has at least one child.
	while(temp!=0&&(temp->left!=0||temp->right!=0))
	{
		const bool usefirst=(temp->dim==1);
		const int keyval=usefirst?key->x.first:key->x.second;
		const int split=usefirst?temp->x.first:temp->x.second;

		node*preferred=(keyval<=split)?temp->left:temp->right;
		node*fallback=(keyval<=split)?temp->right:temp->left;
		// At least one child exists here, so temp never becomes null.
		temp=(preferred!=0)?preferred:fallback;
	}
	return temp;
}
// Straight-line (Euclidean) distance between the points of a and b.
double distance(node*a,node*b)
{
	const double gap1=static_cast<double>(a->x.first)-static_cast<double>(b->x.first);
	const double gap2=static_cast<double>(a->x.second)-static_cast<double>(b->x.second);
	return sqrt(pow(gap1,2)+pow(gap2,2));
}
// Brute-force nearest search over the subtree rooted at root.
// (Author's note, translated: this function is no longer used in this version.)
//
// Visits the subtree breadth-first and returns the node whose distance to key
// is smallest among those strictly closer than mindis. If no node beats
// mindis, root itself is returned. Returns 0 when root is 0 instead of
// dereferencing it.
node*query(node*key,node*root,double mindis)
{
	if(root==0)
		return 0;

	node*rec=root;
	double mind=mindis;
	queue<node*>q;
	q.push(root);
	while(!q.empty())
	{
		node*temp=q.front();
		q.pop();

		const double dis=distance(key,temp);
		if(dis<mind)
		{
			mind=dis;
			rec=temp;
		}
		if(temp->left!=0)
			q.push(temp->left);
		if(temp->right!=0)
			q.push(temp->right);
	}
	return rec;
}
// Recursive nearest-neighbour backtracking inside the subtree rooted at `root`.
//
// Starts at the leaf `belong` and climbs towards `root`. At each ancestor `fat`:
//   1. fat becomes the candidate if it is closer than the best so far;
//   2. `other` is the child of fat that the climb did not come from; if the
//      circle around key with radius mindis crosses fat's splitting line, that
//      subtree is searched by locating key's leaf inside it and recursing.
// The climb stops once it would leave the subtree (fat == root->father), which
// is what keeps the recursive calls confined to `other`.
//
// Two defects of the earlier form are fixed: `other` is no longer initialised
// with a heap node that was immediately overwritten (one leak per iteration),
// and a missing sibling is skipped instead of being passed on as a null
// subtree root.
//
// Returns belong unchanged when key, belong or root is 0.
node*find_nearest(node*key,node*belong,node*root)
{
	if(key==0||belong==0||root==0)
		return belong;

	node*nearest=belong;
	double mindis=distance(key,belong);
	node*const stop=root->father;  // first node outside the subtree being searched
	while(true)
	{
		node*fat=belong->father;
		if(fat==0||fat==stop)
			break;

		node*other=(fat->left==belong)?fat->right:fat->left;

		const double fdis=distance(fat,key);
		if(fdis<mindis)
		{
			mindis=fdis;
			nearest=fat;
		}

		// Distance from key to the splitting line x1=fat->x.first (dim 1) or x2=fat->x.second.
		const int gap=(fat->dim==1)?(fat->x.first-key->x.first):(fat->x.second-key->x.second);
		if(other!=0&&abs(gap)<mindis)
		{
			node*tm=find_first_belong(key,other);
			node*res=find_nearest(key,tm,other);  // the recursive search of the other region
			if(res!=0&&distance(res,key)<mindis)
			{
				mindis=distance(res,key);
				nearest=res;
			}
		}
		belong=fat;
	}
	return nearest;
}
// Nearest-neighbour lookup over the whole tree: locates the leaf region that
// key falls into, then backtracks from that leaf up to root.
node*search(node*key,node*root)
{
	node* const leaf=find_first_belong(key,root);
	return find_nearest(key,leaf,root);
}
int main()
{
	int n;
	cin>>n;
	for(int i=0;i<n;i++)
	{
		int x,y;
		cin>>x>>y;
		node* temp=new node(make_pair(x,y));
		vec.push_back(temp);
	}
	node*root=buildtree(vec,1);
	//traverse(root);
	int x,y;
	cin>>x>>y;
	node *key=new node(make_pair(x,y));
    node*near=search(key,root);
	cout<<"the nearest point is "<<near->x.first<<" "<<near->x.second<<endl;
	
} 
然而還是沒有解決(10,10)的情況,明天再說!!!!