1. 程式人生 > 其它 >平面歐幾里得最小生成樹

平面歐幾里得最小生成樹

定義

思路

這是要求一個完全圖的最小生成樹

那您知道有幾種求最小生成樹的方法嗎?

您可能會說:kruskal,prim,不就兩個嗎

但實際上還有第三個神奇的演算法:boruvka演算法

先說一下這個演算法的簡要流程吧

我們維護圖中所有連通塊,開始是每個點是一個單獨的連通塊,然後每一輪找到每一個連通塊和其他連通塊相連的邊權最小的一條邊,然後把這條邊加入最小生成樹,併合並連通塊。重複這樣的操作,直到最後剩下一整個連通塊。

這樣最多會有\(nlogn\)輪,因為每輪至少減少一半的連通塊的個數(實際上遠遠不到\(nlogn\)輪)

有什麼用呢?暴力不是也是\(O(n^2\log n)\)的嗎

但這裡有一個可以優化的地方,這句話:“每一輪找到每一個連通塊和其他連通塊相連的邊權最小的一條邊。”直接暴力找豈不是太過暴力?我們考慮用一個數據結構來找“和其他連通塊相連的邊權最小的一條邊”這個東西。

再一看...這不是歐幾里得距離下的最近鄰嘛...那好辦了,用KD樹維護罷。

那在每一輪裡,依次列舉所有連通塊,對所有連通塊裡的每個點都找最近鄰不就行了。

假的,因為有可能一個點和他的最近鄰在同一個連通塊裡,所以在遍歷一個連通塊之前,必須要把KD樹中當前連通塊中所有點刪掉,再查最近鄰,這樣才是對的。

至於插入和刪除,我們在每個節點上另記一個是否存在就行了。

還有,查最近鄰不是有個求到當前節點左右兒子矩形的最近距離嗎,我們認真分析一下這個最近距離在歐幾里得距離下怎麼求

我們有這個圖:

在1,3,7,9區域內,最小距離是到4個頂點的距離

而在2,4,6,8區域內,最小距離是到4個邊的距離

最後5區域最小距離就是0

這樣這題就可以解決啦

boruvka演算法最多會有\log nlogn輪,每輪裡遍歷了所有點,每個點插入一次,刪除一次,最近鄰查詢一次,所以總複雜度是常數巨大的\(O(n\log^2 n)\),最壞理論複雜度是\(O(n\sqrt{n}\log n)\),但幾乎無法卡到這個級別。

程式碼

#include<cstdio>
#include<algorithm>
using namespace std;
const int N=1e5+5;
#define F first
#define S second
int L[N],R[N],U[N],D[N],sz[N],Px[N],Py[N],loc[N],num[N],poi[N],nst[N],fa[N],Fa[N],ls[N],rs[N];
double Mn[N];
int n,treesize,cnt,Dir,X,Y,To,rt; double ans,mn;
pair<int,int> a[N]; bool exi[N];
struct Poi{ int x,y; } s[N];
bool operator < (const Poi& a, const Poi& b){ return Dir?a.y<b.y:a.x<b.x; }
int find(int u){ return Fa[u]==u?u:Fa[u]=find(Fa[u]); }
inline void Destroy(int now){ exi[now]=0,L[now]=R[now]=1e9,U[now]=D[now]=-1e9,sz[now]=0; }
inline void Renew(int now){ exi[now]=1,L[now]=R[now]=Px[now],U[now]=D[now]=Py[now],sz[now]=1; }
inline void pushup(int now){
	if (exi[now]) Renew(now); else Destroy(now);
	sz[now]+=sz[ls[now]]+sz[rs[now]];
	L[now]=min(L[now],min(L[ls[now]],L[rs[now]]));
	R[now]=max(R[now],max(R[ls[now]],R[rs[now]]));
	U[now]=min(U[now],min(U[ls[now]],U[rs[now]]));
	D[now]=max(D[now],max(D[ls[now]],D[rs[now]]));
}
void build(int& now, int l, int r, int tag){
	if (l>r) return;
	if (!now) now=++treesize;
	int mid=(l+r)>>1; Dir=tag;
	nth_element(s+l,s+mid,s+r+1);
	loc[mid]=now; num[now]=mid;
	Px[now]=L[now]=R[now]=s[mid].x;
	Py[now]=U[now]=D[now]=s[mid].y;
	sz[now]=1; exi[now]=1;
	build(ls[now],l,mid-1,tag^1); if (ls[now]) fa[ls[now]]=now;
	build(rs[now],mid+1,r,tag^1); if (rs[now]) fa[rs[now]]=now;
	pushup(now);
}
inline double dist(int x, int y){ return (double)sqrt(1ll*(x-X)*(x-X)+1ll*(y-Y)*(y-Y)); }
inline double mndis(int now){
	if (!now || L[now]==1e9) return 1e9;
	if (X<=L[now]){
		if (Y<=U[now]) return dist(L[now],U[now]);
		else if (U[now]<=Y && Y<=D[now]) return L[now]-X;
		else return dist(L[now],D[now]);
	} else
	if (L[now]<=X && X<=R[now]){
		if (Y<=U[now]) return U[now]-Y;
		else if (U[now]<=Y && Y<=D[now]) return 0;
		else return Y-D[now];
	} else {
		if (Y<=U[now]) return dist(R[now],U[now]);
		else if (U[now]<=Y && Y<=D[now]) return X-R[now];
		else return dist(R[now],D[now]);
	}
}
void query(int now){
	if (!now) return;
	if (exi[now]){
		double dis=dist(Px[now],Py[now]);
		if (dis<mn) mn=dis,To=num[now];
	}
	double dl=mndis(ls[now]),dr=mndis(rs[now]);
	if (dl<dr){
		if (dl<mn) query(ls[now]);
		if (dr<mn) query(rs[now]);
	} else {
		if (dr<mn) query(rs[now]);
		if (dl<mn) query(ls[now]);
	}
}
inline int read(){
	int num=0,fu=1; char ch=getchar();
	while (ch<'0' || ch>'9') fu&=(ch!='-'),ch=getchar();
	while (ch>='0' && ch<='9') num=(num<<3)+(num<<1)+ch-'0',ch=getchar();
	return fu?num:-num;
}
int main(){
	n=read(); L[0]=U[0]=1e9,R[0]=D[0]=-1e9;
	for (int i=1; i<=n; i++) s[i].x=read(),s[i].y=read();
	for (int i=1; i<=n; i++) Fa[i]=i;
	build(rt,1,n,0); cnt=n;
	while (cnt>1){
		int now,k=0;
		for (int i=1; i<=n; i++) a[i]=make_pair(find(i),i);
		sort(a+1,a+1+n);
		for (int i=1; i<=n; i=now+1){
			now=i; k++; Mn[k]=1e9; poi[k]=a[i].S;
			while (a[now+1].F==a[i].F) now++;
			for (int j=i; j<=now; j++){
				int u=loc[a[j].S]; Destroy(u);
				while (fa[u]) pushup(u),u=fa[u];
			}
			for (int j=i; j<=now; j++){
				X=s[a[j].S].x,Y=s[a[j].S].y;
				To=0; mn=1e9; query(rt);
				if (mn<Mn[k]) Mn[k]=mn,nst[k]=To;
			}
			for (int j=i; j<=now; j++){
				int u=loc[a[j].S]; Renew(u);
				while (fa[u]) pushup(u),u=fa[u];
			}
		}
		for (int j=1; j<=k; j++){
			int u=find(poi[j]),v=find(nst[j]);
			if (u==v) continue;
			Fa[u]=v; ans+=Mn[j]; cnt--;
		}
	}
	printf("%.6lf\n",ans);
	return 0;
}