KD樹學習筆記(只適合OIer)
阿新 • • 發佈:2018-12-22
先思考一個問題:
- 在K維空間裡面有許多的點,對於某些給定的點,我們需要找到和它最近的m個點。
- 這裡的距離指的是歐幾里得距離:
- D(p,q)=D(q,p)=sqrt((q1-p1)^2+(q2-p2)^2+(q3-p3)^2+...+ (qn-pn)^2),請你幫忙解決一下。
輸入:
- 點數n(1≤n≤50000)和維度數k(1≤k≤5)。
- 接下來的n行,每行k個整數,代表一個點的座標。
- 接下來一個正整數:給定的詢問數量t(1≤t≤10000)
- 下面2*t行:
- 第一行k個整數,表示要查詢的點的座標
- 第二行一個整數m,表示查詢最近的m個點(1≤m≤10)
- 所有座標的絕對值不超過10000。
- 有多組資料!
輸出:
- 對於每個詢問,輸出m+1行:
- 第一行:"the closest m points are:" m為查詢中的m
- 接下來m行每行代表一個點,按照從近到遠排序。
- 保證方案唯一,下面這種情況不會出現:
- 2 2
- 1 1
- 3 3
- 1
- 2 2
- 1
我們知道在二維的情況下我們可以用樹狀陣列來解決(亂搞)。但此時題目中給出了一個會變化的維度,再用樹狀陣列就會提高大量的思維難度(反正我是想象不出5維空間的),
此時我們就需要一種對應多維度的資料結構——KD樹。
- KD樹的定義:
Kd-樹是K-dimension tree的縮寫,是對資料點在k維空間(如二維(x,y),三維(x,y,z),k維(x1,y,z..))中劃分的一種資料結構,主要應用於多維空間關鍵資料的搜尋(如:範圍搜尋和最近鄰搜尋)。本質上說,Kd-樹就是一種平衡二叉樹。
首先必須搞清楚的是,k-d樹是一種空間劃分樹,說白了,就是把整個空間劃分為特定的幾個部分,然後在特定空間的部分內進行相關搜尋操作。想像一個三維(多維有點為難我的想象力了)空間,kd樹按照一定的劃分規則把這個三維空間劃分了多個空間,如下圖:
更加易懂的說法是KD樹實際上就是多關鍵字搜尋 (我蒟蒻只需要知道這個就夠了)。
2. KD樹的構建
KD樹與線段樹的構建相似,需要動態遞迴建立
void build(int &k,int l,int r,int dir)
{
int mid=(l+r)>>1;
k=mid;D=dir;
nth_element(a+l,a+mid,a+r+1,cmp);
for(int i=0;i<K;i++)
a[k].mi[i]=a[k].mx[i]=a[k].d[i];
if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
pushup(k);
}
3.KD樹的插入
雖然此題不用插入但我們還是要學的啊
void insert(int k,int dir)
{
if (q[dir]<a[k].d[dir])
{
if (a[k].l) insert(a[k].l,(dir+1)%d);
else
{
a[k].l=++n;
for(int i=0;i<K;i++)
a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
}
}
else
{
if (a[k].r) insert(a[k].r,(dir+1)%d);
else
{
a[k].r=++n;
for(int i=0;i<K;i++)
a[n].mi[i]=a[n].mx[i]=a[n].d[i]=q[i];
}
}
pushup(k);//同時向上維護
}
雖然一棵剛建好的KD樹深度是O(log)的。但隨便亂插會對時間有巨大的負擔很容易TLE。所以我們可以用替罪羊樹優化……因為博主太弱還不會(QAQ)所以請同學們自己去學習吧(學會了記得回來給我講講啊)!
4.KD樹的查詢
KD樹的關鍵。還記得我們維護的mi[]和mx[],現在我們要用它來做估計了。我們都知道估計可以省下大量的計算,所以這也是KD樹獨特的地方。但我們的答案不能是估計啊!所以精確的也不能少(QAQ)
long long Guess(int k) //估算與k點的距離值
{
long long i,s=0;
for(i=0;i<K;i++)
{
if(q[i]<a[k].mi[i])s+=(long long)(q[i]-a[k].mi[i])*(q[i]-a[k].mi[i]);
if(q[i]>a[k].mx[i])s+=(long long)(q[i]-a[k].mx[i])*(q[i]-a[k].mx[i]);
}
return s;
}
long long Dis(int k) //求查詢點與k點的距離值
{
long long i,ans=0;
for(i=0;i<K;i++)ans+=(long long)(q[i]-a[k].d[i])*(q[i]-a[k].d[i]);
return ans;
}
void Query(int x)
{
if(!x)return;
long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r);
if(dis<Q.top().first)//為本題需要而建的大根堆
{
Q.pop();
Q.push(make_pair(dis,x));
}
if(dl<dr)
{
if(dl<Q.top().first)Query(a[x].l);
if(dr<Q.top().first)Query(a[x].r);
}
else
{
if(dr<Q.top().first)Query(a[x].r);
if(dl<Q.top().first)Query(a[x].l);
}
}
原題程式碼:
#include<bits/stdc++.h>
#define INF 0x3f3f3f3
using namespace std;
typedef pair<long long,int>pii;
priority_queue<pii>Q;
struct data{int d[6],mx[6],mi[6],l,r;}a[100005<<1];
int q[6],i,j,k,m,n,rt,D,K,t;
bool cmp(data x,data y){return x.d[D]<y.d[D];}
void pushup(int x)
{
int i,ls=a[x].l,rs=a[x].r;
for(i=0;i<K;i++)
{
if(ls)
{
a[x].mx[i]=max(a[x].mx[i],a[ls].mx[i]);
a[x].mi[i]=min(a[x].mi[i],a[ls].mi[i]);
}
if(rs)
{
a[x].mx[i]=max(a[x].mx[i],a[rs].mx[i]);
a[x].mi[i]=min(a[x].mi[i],a[rs].mi[i]);
}
}
}
void build(int &k,int l,int r,int dir)
{
int mid=(l+r)>>1;
k=mid;D=dir;
nth_element(a+l,a+mid,a+r+1,cmp);
for(int i=0;i<K;i++)
a[k].mi[i]=a[k].mx[i]=a[k].d[i];
if(l<mid)build(a[k].l,l,mid-1,(dir+1)%K);
if(r>mid)build(a[k].r,mid+1,r,(dir+1)%K);
pushup(k);
}
long long Guess(int x) //估算與x點的距離值
{
long long i,s=0;
for(i=0;i<K;i++)
{
if(q[i]<a[x].mi[i])s+=(long long)(q[i]-a[x].mi[i])*(q[i]-a[x].mi[i]);
if(q[i]>a[x].mx[i])s+=(long long)(q[i]-a[x].mx[i])*(q[i]-a[x].mx[i]);
}
return s;
}
long long Dis(int x) //求查詢點與x點的距離值
{
long long i,ans=0;
for(i=0;i<K;i++)ans+=(long long)(q[i]-a[x].d[i])*(q[i]-a[x].d[i]);
return ans;
}
void Query(int x)
{
if(!x)return;
long long dis=Dis(x),dl=Guess(a[x].l),dr=Guess(a[x].r);
if(dis<Q.top().first)
{
Q.pop();
Q.push(make_pair(dis,x));
}
if(dl<dr)
{
if(dl<Q.top().first)Query(a[x].l);
if(dr<Q.top().first)Query(a[x].r);
}
else
{
if(dr<Q.top().first)Query(a[x].r);
if(dl<Q.top().first)Query(a[x].l);
}
}
void print() //從小到大輸出m個點
{
int i,x;
while(!Q.empty())
{
x=Q.top().second;Q.pop();
print();
for(i=0;i<K;i++)printf("%d ",a[x].d[i]);
printf("\n");
}
}
int main()
{
while(scanf("%d%d",&n,&K)!=EOF)
{
memset(a,0,sizeof(a));
while(!Q.empty())Q.pop();//清空堆
for(i=1;i<=n;i++) //讀入n個點的座標
for(j=0;j<K;j++)
scanf("%d",&a[i].d[j]);
build(rt,1,n,0);
scanf("%d",&t); //建立KD樹 scanf("%d",&t);
for(i=1;i<=t;i++) //t組詢問
{
for(j=0;j<K;j++)scanf("%d",&q[j]);//讀入查詢點的座標
scanf("%d",&m);
for(j=1;j<=m;j++)Q.push(make_pair(INF,0));//把k個INF加入大根堆
Query(rt);
printf("the closest %d points are:\n",m);
print();
}
}
return 0;
}