1. 程式人生 > 其它 >P8339-[AHOI2022]鑰匙【虛樹,掃描線】

P8339-[AHOI2022]鑰匙【虛樹,掃描線】

正題

題目連線:https://www.luogu.com.cn/problem/P8339


題目大意

給出\(n\)個點的一棵樹,每個點有鑰匙或者寶箱,有不同的顏色。

\(m\)次詢問,從\(x\)走到\(y\),走到鑰匙時會拾取鑰匙,走到寶箱時如果有同色的鑰匙那麼就會消耗一把鑰匙開啟寶箱,詢問能開啟多少個寶箱。

保證每一種顏色的鑰匙不超過\(5\)把。

\(1\leq n\leq 5\times 10^5,1\leq m\leq 10^6\)


解題思路

先考慮同色的寶箱和鑰匙都只有一個的情況,這是一個經典問題,假設分別為\(x,y\),那麼刪去\(x\leftrightarrow y\)的路徑,\(x\)

的聯通塊記為\(S\)\(y\)的聯通塊記為\(T\)

如果詢問節點起點在\(S\),終點在\(T\)就會產生貢獻。

那麼\(S\)\(T\)要麼兩個都是子樹,要麼一個是子樹,另一個是整棵樹刪去一個子樹,也就是說它們都可以表示成\(dfs\)序上的一個或兩個連續區間。

那麼我們把兩個區間視為一個二維平面上的正方形\(+1\),然後詢問的視為查詢一個點的值,實現方法就是把這些都離線下來用掃描線。

好現在考慮這一題,我們會發現一條路徑上我們把單種顏色的拿出來,鑰匙視為\((\),寶箱視為\()\),那麼就是一個類似括號匹配的東西,每一對產生貢獻的點都會滿足中間是一個合法的括號序。

那麼我們從這個性質入手,我們列舉所有顏色,把同色的點建一棵虛樹,對於每個鑰匙我們暴力掃全圖,能找到很多個合法的貢獻對\(x,y\)

,像上面的方法掃描線就好了。

實際上我們會發現這樣枚舉出來的貢獻對其實是\(n\)個而不是\(5n\)個的。

時間複雜度:\(O((n+m)\log n)\)


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#include<vector>
#include<stack>
#define mp(x,y) make_pair(x,y)
#define lowbit(x) (x&-x)
using namespace std;
const int N=5e5+10;
struct node{
	int to,next;
}a[N<<1];
int n,m,tot,Top,cnt,ls[N],t[N],c[N],s[N],ans[N];
int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],rfn[N],ed[N];
vector<int> G[N],p[N];stack<int> cl;
vector<pair<int,int> >I[N],O[N],q[N];
void addl(int x,int y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;return;
}
bool cmp(int x,int y)
{return rfn[x]<rfn[y];}
void dfs(int x){
	siz[x]=1;dep[x]=dep[fa[x]]+1;
	for(int i=ls[x];i;i=a[i].next){
		int y=a[i].to;
		if(y==fa[x])continue;
		fa[y]=x;dfs(y);siz[x]+=siz[y];
		if(siz[y]>siz[son[x]])son[x]=y;
	}
	return;
}
void dfs2(int x){
	dfn[++cnt]=x;rfn[x]=cnt;
	if(son[x]){
		top[son[x]]=top[x];
		dfs2(son[x]);
	}
	for(int i=ls[x];i;i=a[i].next){
		int y=a[i].to;
		if(y==fa[x]||y==son[x])continue;
		top[y]=y;dfs2(y);
	}
	ed[x]=cnt;return;
}
int LCA(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])
			swap(x,y);
		x=fa[top[x]];
	}
	return (dep[x]<dep[y])?x:y;
}
int getTop(int x,int y){
	while(top[y]!=top[x])
		if(fa[top[y]]==x)
			return top[y];
		else y=fa[top[y]];
	return dfn[rfn[x]+1];
}
void addG(int x,int y){
	G[x].push_back(y);
	G[y].push_back(x);
	cl.push(x);cl.push(y);
	return;
}
void Clear(){
	Top=0;
	while(!cl.empty())
	{G[cl.top()].clear();cl.pop();}
}
void Ins(int x){
	if(!Top){s[++Top]=x;return;}
	int lca=LCA(s[Top],x);
	while(Top>1&&dep[s[Top-1]]>=dep[lca])
		addG(s[Top-1],s[Top]),Top--;
	if(dep[s[Top]]>dep[lca])
		addG(lca,s[Top]),Top--;
	if(s[Top]!=lca)s[++Top]=lca;
	s[++Top]=x;return;
}
void Build(vector<int> &p){
	sort(p.begin(),p.end(),cmp);
	if(p[0]!=1)Ins(1);
	for(int i=0;i<p.size();i++)Ins(p[i]);
	while(Top>1)addG(s[Top-1],s[Top]),Top--;
}
void Sets(int x,int y){
	int lca=LCA(x,y);
	if(lca==x){
		x=getTop(x,y);
		I[1].push_back(mp(rfn[y],ed[y]));
		O[rfn[x]].push_back(mp(rfn[y],ed[y]));
		I[ed[x]+1].push_back(mp(rfn[y],ed[y]));
	}
	else if(lca==y){
		y=getTop(y,x);
		if(rfn[y]>1)I[rfn[x]].push_back(mp(1,rfn[y]-1));
		if(ed[y]<n)I[rfn[x]].push_back(mp(ed[y]+1,n));
		if(rfn[y]>1)O[ed[x]+1].push_back(mp(1,rfn[y]-1));
		if(ed[y]<n)O[ed[x]+1].push_back(mp(ed[y]+1,n));
	}
	else{
		I[rfn[x]].push_back(mp(rfn[y],ed[y]));
		O[ed[x]+1].push_back(mp(rfn[y],ed[y]));
	}
	return;
}
void calc(int x,int fa,int k,int &from,int &_){
	if(c[x]==-_){k++;}
	if(c[x]==_){
		k--;
		if(!k){
			Sets(from,x);
			return;
		}
	}
	for(int i=0;i<G[x].size();i++)
		if(G[x][i]!=fa)calc(G[x][i],x,k,from,_);
}
void Change(int x,int val){
	while(x<=n){
		t[x]+=val;
		x+=lowbit(x);
	}
	return;
}
int Ask(int x){
	int ans=0;
	while(x){
		ans+=t[x];
		x-=lowbit(x);
	}
	return ans;
}
int main()
{
	scanf("%d%d",&n,&m);
	for(int i=1,t;i<=n;i++){
		scanf("%d%d",&t,&c[i]);
		p[c[i]].push_back(i);
		if(t==1)c[i]=-c[i];
	}
	for(int i=1,x,y;i<n;i++){
		scanf("%d%d",&x,&y);
		addl(x,y);addl(y,x);
	}
	dfs(1);dfs2(1);
	for(int _=1;_<=n;_++){
		if(p[_].empty())continue;
		Build(p[_]);
		for(int i=0;i<p[_].size();i++)
			if(c[p[_][i]]==-_)
				calc(p[_][i],0,0,p[_][i],_);
		Clear();
	}
	for(int i=1,x,y;i<=m;i++)
		scanf("%d%d",&x,&y),q[rfn[x]].push_back(mp(rfn[y],i));
	for(int i=1;i<=n;i++){
		for(int j=0;j<I[i].size();j++)
			Change(I[i][j].first,1),Change(I[i][j].second+1,-1);
		for(int j=0;j<O[i].size();j++)
			Change(O[i][j].first,-1),Change(O[i][j].second+1,1);
		for(int j=0;j<q[i].size();j++)
			ans[q[i][j].second]=Ask(q[i][j].first);
	}
	for(int i=1;i<=m;i++)
		printf("%lld\n",ans[i]);
	return 0;
}