1. 程式人生 > 其它 ><題解>世界樹

<題解>世界樹

世界樹<題解>

首先我們拿到這個題之後,能想到的一定是虛樹,如果想不到的話,還是重新學一遍去吧

所以我們應該怎麼做呢

虛樹的板子不需要我再講一遍了吧

所以對於這個題來說,怎麼根據虛樹上的節點來找到每一個點的集合的大小

針對這種在樹上求集合大小的題,不是dp就是利用siz(子樹大小)來容斥求得

然而這個題就是巧妙的利用了siz這個東西

(注意接下來說的,兒子是指虛樹中的兒子,子樹是指原樹中的子樹)

vis[]陣列用來標記這個節點是否為議事處

我們利用who[]和dis[]兩個陣列來實現在這個虛樹上的容斥

who[x]表示距離x這個節點最近的被標記的節點

dis[x]表示x與who[x]之間的距離

我們想要求這兩個值,就需要分為兩部分去求,父親那邊的,虛樹子樹內的

所以我們進行兩次dfs,一次向下尋找,一次向上尋找,得到這個陣列

int who[N],dis[N];
void dfs_sol1(int x){
	who[x]=dis[x]=inf;
	if(vis[x])who[x]=x,dis[x]=0;
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		dfs_sol1(y);
		int tmp=DIS(x,who[y]);
		if(tmp<dis[x]||tmp==dis[x]&&who[y]<who[x])who[x]=who[y],dis[x]=tmp;
	}
}
void dfs_sol2(int x){
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		int tmp=DIS(y,who[x]);
		if(tmp<dis[y]||tmp==dis[y]&&who[x]<who[y])who[y]=who[x],dis[y]=tmp;
		dfs_sol2(y);
	}
}

為什麼不需要分為兩種,父親的和兒子的

因為我們對於每一個節點,他和他兒子的who不一樣的話

要麼他兒子本身就被標記過了,要麼是兒子的who在兒子的虛樹子樹內

所以不會出現有兩組x到who[x]的路徑重複的情況

所以我們只需要記錄一個最近節點

那麼最後我們就需要去求每個議事處的集合大小了

既然是在樹上,我們很容易想到是要用到一個dfs的

那我們每遍歷到一個節點就將這個節點的siz加到這個節點的who上

然後我們就進行一次判斷,

     如果此時節點與他的兒子節點的who相同,那我們就直接減去兒子節點的siz

     如果不相同的話,我們就利用倍增的辦法,找到這個分界點,使得分界點兩側一側是分到
     當前節點的who上,一側是分到兒子節點的who上,這樣我們直接利用siz進行加減就好了

如何找到分界點呢?

因為每兩個點之間的距離都是1,所以這個分界點一定是這條鏈的中點,再特判一下,搞定中點就可以啦

為什麼要用倍增??因為快

$ code $



#include<bits/stdc++.h>
using namespace std;
#define re register int 
#define ll long long
#define inf 0x3f3f3f3f
const int N=300005;
int n,q,m,h[N];
int to[N*2],nxt[N*2],head[N],rp;
int dfn[N],cnt;
int dep[N],fa[N][21],siz[N];
void add_edg(int x,int y){
	to[++rp]=y;
	nxt[rp]=head[x];
	head[x]=rp;
}
void dfs_first(int x){
	dfn[x]=++cnt;
	siz[x]=1;
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		if(y==fa[x][0])continue;
		fa[y][0]=x;
		for(re i=1;i<=20;i++)fa[y][i]=fa[fa[y][i-1]][i-1];//cout<<fa[y][i]<<endl;
		dep[y]=dep[x]+1;
		dfs_first(y);
		siz[x]+=siz[y];
	}
}
int LCA(int x,int y){
	if(dep[x]<dep[y])swap(x,y);
	for(re i=20;i>=0;i--)
		if(dep[fa[x][i]]>=dep[y])
			x=fa[x][i];
	if(x==y)return x;
	for(re i=20;i>=0;i--)
		if(fa[x][i]!=fa[y][i])
			x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
int DIS(int x,int y){
	return dep[x]+dep[y]-2*dep[LCA(x,y)];
}
int sta[N],tot;
bool vis[N];
bool cmp(int x,int y){
	return dfn[x]<dfn[y];
}
void build_vtree(){
	sort(h+1,h+m+1,cmp);
	sta[tot=1]=1;head[1]=0;rp=0;
	for(re i=1;i<=m;i++){
		if(h[i]==1)continue;
		int lca=LCA(sta[tot],h[i]);
		//cout<<lca<<endl;
		if(lca!=sta[tot]){
			while(dfn[lca]<dfn[sta[tot-1]])
				add_edg(sta[tot-1],sta[tot]),tot--;
			if(lca!=sta[tot-1]){
				head[lca]=0;
				add_edg(lca,sta[tot]);
				sta[tot]=lca;
			}
			else add_edg(lca,sta[tot]),tot--;
		}
		head[h[i]]=0;
		sta[++tot]=h[i];
	}
	//cout<<tot<<endl;
	for(re i=1;i<tot;i++)add_edg(sta[i],sta[i+1]);
}
int who[N],dis[N];
void dfs_sol1(int x){
	who[x]=dis[x]=inf;
	if(vis[x])who[x]=x,dis[x]=0;
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		dfs_sol1(y);
		int tmp=DIS(x,who[y]);
		if(tmp<dis[x]||tmp==dis[x]&&who[y]<who[x])who[x]=who[y],dis[x]=tmp;
	}
}
void dfs_sol2(int x){
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		int tmp=DIS(y,who[x]);
		if(tmp<dis[y]||tmp==dis[y]&&who[x]<who[y])who[y]=who[x],dis[y]=tmp;
		dfs_sol2(y);
	}
}
int get_fa(int x,int t){
	for(re i=20;i>=0;i--)
		if(t>=1<<i)
			x=fa[x][i],t-=1<<i;
	return x;
}
int ans[N];
void dfs_ans(int x){
	ans[who[x]]+=siz[x];
	for(re i=head[x];i;i=nxt[i]){
		int y=to[i];
		//cout<<x<<" "<<who[x]<<" "<<y<<" "<<who[y]<<endl;
		if(who[x]==who[y])ans[who[x]]-=siz[y];
		else{
			int tmp=DIS(x,y)-1+dis[x]-dis[y]>>1;
			if(!(DIS(who[x],who[y])&1)&&who[y]<who[x])tmp++;
			int z=get_fa(y,tmp);
			ans[who[x]]-=siz[z];
			ans[who[y]]+=siz[z]-siz[y];
		}
		dfs_ans(y);
	}
}
signed main(){
	scanf("%d",&n);
	for(re i=1,x,y;i<n;i++){
		scanf("%d%d",&x,&y);
		add_edg(x,y);
		add_edg(y,x);
	}
	dep[1]=1;
	dfs_first(1);
	scanf("%d",&q);
	int an[N];
	while(q--){
		scanf("%d",&m);
		for(re i=1;i<=m;i++){
			scanf("%d",&h[i]);
			an[i]=h[i];
			vis[h[i]]=1;
		}
		build_vtree();
		dfs_sol1(1);
		dfs_sol2(1);
		for(re i=1;i<=m;i++)ans[an[i]]=0;
		dfs_ans(1);
		for(re i=1;i<=m;i++)printf("%d ",ans[an[i]]);
		printf("\n");
		for(re i=1;i<=m;i++)vis[h[i]]=0;
	}
}

完事了。。。。。