1. 程式人生 > 其它 >「題解」樹套樹 tree

「題解」樹套樹 tree

本文將同步釋出於:

題目

題目描述

給你一個 \(n\) 個點的小樹(正常的樹),給你一個 \(m\) 個點的大樹,大樹的節點是一棵小樹,大樹的邊是跨越了兩棵小樹之間的邊,\(q\) 次詢問,求樹上距離。

\(1\leq n,m,q\leq 4\times 10^4\)

題解

預處理

思路非常簡單,我們顯然可以通過一系列操作 \(\Theta(n)\)\(\Theta(n\log_2n)\) 預處理,使得可以在 \(\Theta(1)\) 或者 \(\Theta(\log_2n)\) 求出小樹任意兩點間的距離。

大樹倍增

我們在大樹的每個節點儲存一點資訊:

  • \(\texttt{fa}_i\):編號為 \(i\) 的大樹節點在大樹上的祖先為 \(\texttt{fa}_i\)
  • \(\texttt{rt}_i\):編號為 \(i\) 的大樹節點連線 \(\texttt{fa}_i\) 對應小樹節點為 \(\texttt{rt}_i\)
  • \(\texttt{ptr}_i\)\(\texttt{rt}_i\) 在實際的樹中對應的祖先,也就是編號為 \(\texttt{fa}_i\) 中與 \(i\) 相連的小樹節點編號。

維護了以上資訊後,我們再維護 \(\texttt{dis}_i\),表示 \(\texttt{rt}_i\) 到實際的樹的根的距離。

然後直接倍增加分類討論即可解決問題。

優化時間複雜度

不難看出,最簡單的做法的時間複雜度為 \(\Theta\left(n\log_2n+m\left(\log_2n+\log_2m\right)+q\left(\log_2n+\log_2m\right)\right)\)

我們可以通過 \(\Theta(n)\) 構造的 ST 表輕鬆將複雜度降到 \(\Theta(n+m+q)\),考慮到程式碼複雜度偏大,就沒有具體實現。

參考程式

參考程式的時間複雜度為 \(\Theta\left(n+m\left(\log_2n+\log_2m\right)+q\left(\log_2n+\log_2m\right)\right)\)

,通過樹剖消除了一個 \(\log\)

#pragma GCC optimize("Ofast")
#include<bits/stdc++.h>
using namespace std;
#define reg register
typedef long long ll;

bool st;

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
static char buf[1<<21],*p1=buf,*p2=buf;
#define flush() (fwrite(wbuf,1,wp1,stdout),wp1=0)
#define putchar(c) (wp1==wp2&&(flush(),0),wbuf[wp1++]=c)
static char wbuf[1<<21];int wp1;const int wp2=1<<21;
inline int read(void){
	reg char ch=getchar();
	reg int res=0;
	while(!isdigit(ch)) ch=getchar();
	while(isdigit(ch)) res=10*res+(ch^'0'),ch=getchar();
	return res;
}

inline void writeln(reg int x){
	static char buf[32];
	reg int p=-1;
	if(!x) putchar('0');
	else while(x) buf[++p]=(x%10)^'0',x/=10;
	while(~p) putchar(buf[p--]);
	putchar('\n');
	return;
}

inline void swap(reg int &x,reg int &y){
	reg int tmp=x;
	x=y,y=tmp;
	return;
}

const int MAXN=4e4+5;
const int MAXLOG2N=16+1;
const int MAXM=4e4+5;
const int MAXLOG2M=16+1;
const int MAXQ=4e4+5;

int n,m,q;

namespace Small{
	int cnt,head[MAXN],to[MAXN<<1],Next[MAXN<<1];
	inline void Add_Edge(reg int u,reg int v){
		Next[++cnt]=head[u];
		to[cnt]=v;
		head[u]=cnt;
		return;
	}
	inline void Add_Tube(reg int u,reg int v){
		Add_Edge(u,v),Add_Edge(v,u);
		return;
	}
	int fa[MAXN],dep[MAXN];
	int siz[MAXN],son[MAXN];
	inline void dfs1(reg int u,reg int father){
		siz[u]=1;
		fa[u]=father;
		dep[u]=dep[father]+1;
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father){
				dfs1(v,u);
				if(siz[son[u]]<siz[v])
					son[u]=v;
			}
		}
		return;
	}
	int top[MAXN];
	inline void dfs2(reg int u,reg int father,reg int topf){
		top[u]=topf;
		if(!son[u])
			return;
		dfs2(son[u],u,topf);
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father&&v!=son[u])
				dfs2(v,u,v);
		}
		return;
	}
	inline int LCA(reg int x,reg int y){
		while(top[x]!=top[y])
			if(dep[top[x]]>dep[top[y]])
				x=fa[top[x]];
			else
				y=fa[top[y]];
		return dep[x]<dep[y]?x:y;
	}
	inline int getDis(reg int x,reg int y){
		return dep[x]+dep[y]-(dep[LCA(x,y)]<<1);
	}
}

namespace Big{
	int cnt,head[MAXN],to[MAXN<<1],st[MAXN<<1],ed[MAXN<<1],Next[MAXN<<1];
	inline void Add_Edge(reg int u,reg int v,reg int s,reg int e){
		Next[++cnt]=head[u];
		to[cnt]=v,st[cnt]=s,ed[cnt]=e;
		head[u]=cnt;
		return;
	}
	inline void Add_Tube(reg int u,reg int v,reg int s,reg int e){
		Add_Edge(u,v,s,e),Add_Edge(v,u,e,s);
		return;
	}
	int fa[MAXM][MAXLOG2M],dep[MAXM];
	int dis[MAXM];
	int rt[MAXM],ptr[MAXM];
	inline void dfs(reg int u,reg int father,reg int e,reg int s){
		dep[u]=dep[father]+1;
		fa[u][0]=father;
		for(reg int i=1;(1<<i)<=dep[u];++i)
			fa[u][i]=fa[fa[u][i-1]][i-1];
		if(father)
			rt[u]=e,ptr[u]=s,dis[u]=dis[father]+Small::getDis(s,rt[father])+1;
		else
			rt[u]=1,ptr[u]=0,dis[u]=0;
		for(reg int i=head[u];i;i=Next[i]){
			reg int v=to[i];
			if(v!=father)
				dfs(v,u,ed[i],st[i]);
		}
		return;
	}
	inline int LCA(int x,int y){
		if(dep[x]>dep[y])
			swap(x,y);
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(dep[fa[y][i]]>=dep[x])
				y=fa[y][i];
		if(x==y)
			return x;
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(fa[x][i]!=fa[y][i])
				x=fa[x][i],y=fa[y][i];
		return fa[x][0];
	}
	inline pair<int,int> LCA_lower(int x,int y){
		if(dep[x]>dep[y])
			swap(x,y);
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(dep[fa[y][i]]>dep[x])
				y=fa[y][i];
		if(fa[y][0]==x)
			return make_pair(y,0);
		if(dep[y]>dep[x])
			y=fa[y][0];
		for(reg int i=MAXLOG2N-1;i>=0;--i)
			if(fa[x][i]!=fa[y][i])
				x=fa[x][i],y=fa[y][i];
		return make_pair(x,y);
	}
}

bool ed;

int main(void){
	n=read(),m=read(),q=read();
	for(reg int i=1;i<n;++i){
		static int x,y;
		x=read(),y=read();
		Small::Add_Tube(x,y);
	}
	Small::dfs1(1,0),Small::dfs2(1,0,1);
	for(reg int i=1;i<m;++i){
		static int w,x,y,z;
		w=read(),x=read(),y=read(),z=read();
		Big::Add_Tube(w,y,x,z);
	}
	Big::dfs(1,0,1,0);

	/*
	puts("============");
	puts("Small:");
	for(reg int i=1;i<=n;++i)
		printf("i=%d fa=%d dep=%d\n",i,Small::fa[i][0],Small::dep[i]);
	puts("============");
	puts("Big:");
	for(reg int i=1;i<=m;++i)
		printf("i=%d fa=%d dep=%d dis=%lld rt=%d ptr=%d\n",i,Big::fa[i][0],Big::dep[i],Big::dis[i],Big::rt[i],Big::ptr[i]);
	puts("============");
	*/

	while(q--){
		static int w,x,y,z,part1,part2,part3,bLca;
		static pair<int,int> p;
		w=read(),x=read(),y=read(),z=read();
		//printf("query w=%d x=%d y=%d z=%d\n",w,x,y,z);
		if(w==y){
			//puts("S1");
			writeln(Small::getDis(x,z));
		}
		else{
			bLca=Big::LCA(w,y);
			if(bLca==w||bLca==y){
				//puts("S2");
				if(bLca==y)
					swap(w,y),swap(x,z);
				p=Big::LCA_lower(w,y);
				part1=Small::getDis(z,Big::rt[y]);
				part2=Big::dis[y]-Big::dis[p.first];
				part3=1+Small::getDis(Big::ptr[p.first],x);
				//printf("part1=%d part2=%d part3=%d\n",part1,part2,part3);
				writeln(part1+part2+part3);
			}
			else{
				//puts("S3");
				p=Big::LCA_lower(w,y);
				part1=Small::getDis(x,Big::rt[w])+Small::getDis(z,Big::rt[y]);
				part2=Big::dis[w]-Big::dis[p.first]+Big::dis[y]-Big::dis[p.second];
				part3=2+Small::getDis(Big::ptr[p.first],Big::ptr[p.second]);
				//printf("part1=%d part2=%d part3=%d\n",part1,part2,part3);
				writeln(part1+part2+part3);
			}
		}
	}
	flush();
	fprintf(stderr,"%.3lf s\n",1.0*clock()/CLOCKS_PER_SEC);
	fprintf(stderr,"%.3lf MiB\n",(&ed-&st)/1048576.0);
	return 0;
}