P8339-[AHOI2022]鑰匙【虛樹,掃描線】
阿新 • • 發佈:2022-05-20
正題
題目連線:https://www.luogu.com.cn/problem/P8339
題目大意
給出\(n\)個點的一棵樹,每個點有鑰匙或者寶箱,有不同的顏色。
\(m\)次詢問,從\(x\)走到\(y\),走到鑰匙時會拾取鑰匙,走到寶箱時如果有同色的鑰匙那麼就會消耗一把鑰匙開啟寶箱,詢問能開啟多少個寶箱。
保證每一種顏色的鑰匙不超過\(5\)把。
\(1\leq n\leq 5\times 10^5,1\leq m\leq 10^6\)
解題思路
先考慮同色的寶箱和鑰匙都只有一個的情況,這是一個經典問題,假設分別為\(x,y\),那麼刪去\(x\leftrightarrow y\)的路徑,\(x\)
如果詢問節點起點在\(S\),終點在\(T\)就會產生貢獻。
那麼\(S\)和\(T\)要麼兩個都是子樹,要麼一個是子樹,另一個是整棵樹刪去一個子樹,也就是說它們都可以表示成\(dfs\)序上的一個或兩個連續區間。
那麼我們把兩個區間視為一個二維平面上的正方形\(+1\),然後詢問的視為查詢一個點的值,實現方法就是把這些都離線下來用掃描線。
好現在考慮這一題,我們會發現一條路徑上我們把單種顏色的拿出來,鑰匙視為\((\),寶箱視為\()\),那麼就是一個類似括號匹配的東西,每一對產生貢獻的點都會滿足中間是一個合法的括號序。
那麼我們從這個性質入手,我們列舉所有顏色,把同色的點建一棵虛樹,對於每個鑰匙我們暴力掃全圖,能找到很多個合法的貢獻對\(x,y\)
實際上我們會發現這樣枚舉出來的貢獻對其實是\(n\)個而不是\(5n\)個的。
時間複雜度:\(O((n+m)\log n)\)
code
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> #include<stack> #define mp(x,y) make_pair(x,y) #define lowbit(x) (x&-x) using namespace std; const int N=5e5+10; struct node{ int to,next; }a[N<<1]; int n,m,tot,Top,cnt,ls[N],t[N],c[N],s[N],ans[N]; int siz[N],dep[N],son[N],fa[N],top[N],dfn[N],rfn[N],ed[N]; vector<int> G[N],p[N];stack<int> cl; vector<pair<int,int> >I[N],O[N],q[N]; void addl(int x,int y){ a[++tot].to=y; a[tot].next=ls[x]; ls[x]=tot;return; } bool cmp(int x,int y) {return rfn[x]<rfn[y];} void dfs(int x){ siz[x]=1;dep[x]=dep[fa[x]]+1; for(int i=ls[x];i;i=a[i].next){ int y=a[i].to; if(y==fa[x])continue; fa[y]=x;dfs(y);siz[x]+=siz[y]; if(siz[y]>siz[son[x]])son[x]=y; } return; } void dfs2(int x){ dfn[++cnt]=x;rfn[x]=cnt; if(son[x]){ top[son[x]]=top[x]; dfs2(son[x]); } for(int i=ls[x];i;i=a[i].next){ int y=a[i].to; if(y==fa[x]||y==son[x])continue; top[y]=y;dfs2(y); } ed[x]=cnt;return; } int LCA(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } return (dep[x]<dep[y])?x:y; } int getTop(int x,int y){ while(top[y]!=top[x]) if(fa[top[y]]==x) return top[y]; else y=fa[top[y]]; return dfn[rfn[x]+1]; } void addG(int x,int y){ G[x].push_back(y); G[y].push_back(x); cl.push(x);cl.push(y); return; } void Clear(){ Top=0; while(!cl.empty()) {G[cl.top()].clear();cl.pop();} } void Ins(int x){ if(!Top){s[++Top]=x;return;} int lca=LCA(s[Top],x); while(Top>1&&dep[s[Top-1]]>=dep[lca]) addG(s[Top-1],s[Top]),Top--; if(dep[s[Top]]>dep[lca]) addG(lca,s[Top]),Top--; if(s[Top]!=lca)s[++Top]=lca; s[++Top]=x;return; } void Build(vector<int> &p){ sort(p.begin(),p.end(),cmp); if(p[0]!=1)Ins(1); for(int i=0;i<p.size();i++)Ins(p[i]); while(Top>1)addG(s[Top-1],s[Top]),Top--; } void Sets(int x,int y){ int lca=LCA(x,y); if(lca==x){ x=getTop(x,y); I[1].push_back(mp(rfn[y],ed[y])); O[rfn[x]].push_back(mp(rfn[y],ed[y])); I[ed[x]+1].push_back(mp(rfn[y],ed[y])); } else if(lca==y){ y=getTop(y,x); if(rfn[y]>1)I[rfn[x]].push_back(mp(1,rfn[y]-1)); if(ed[y]<n)I[rfn[x]].push_back(mp(ed[y]+1,n)); if(rfn[y]>1)O[ed[x]+1].push_back(mp(1,rfn[y]-1)); if(ed[y]<n)O[ed[x]+1].push_back(mp(ed[y]+1,n)); } else{ I[rfn[x]].push_back(mp(rfn[y],ed[y])); O[ed[x]+1].push_back(mp(rfn[y],ed[y])); } return; } void calc(int x,int fa,int k,int &from,int &_){ if(c[x]==-_){k++;} if(c[x]==_){ k--; if(!k){ Sets(from,x); return; } } for(int i=0;i<G[x].size();i++) if(G[x][i]!=fa)calc(G[x][i],x,k,from,_); } void Change(int x,int val){ while(x<=n){ t[x]+=val; x+=lowbit(x); } return; } int Ask(int x){ int ans=0; while(x){ ans+=t[x]; x-=lowbit(x); } return ans; } int main() { scanf("%d%d",&n,&m); for(int i=1,t;i<=n;i++){ scanf("%d%d",&t,&c[i]); p[c[i]].push_back(i); if(t==1)c[i]=-c[i]; } for(int i=1,x,y;i<n;i++){ scanf("%d%d",&x,&y); addl(x,y);addl(y,x); } dfs(1);dfs2(1); for(int _=1;_<=n;_++){ if(p[_].empty())continue; Build(p[_]); for(int i=0;i<p[_].size();i++) if(c[p[_][i]]==-_) calc(p[_][i],0,0,p[_][i],_); Clear(); } for(int i=1,x,y;i<=m;i++) scanf("%d%d",&x,&y),q[rfn[x]].push_back(mp(rfn[y],i)); for(int i=1;i<=n;i++){ for(int j=0;j<I[i].size();j++) Change(I[i][j].first,1),Change(I[i][j].second+1,-1); for(int j=0;j<O[i].size();j++) Change(O[i][j].first,-1),Change(O[i][j].second+1,1); for(int j=0;j<q[i].size();j++) ans[q[i][j].second]=Ask(q[i][j].first); } for(int i=1;i<=m;i++) printf("%lld\n",ans[i]); return 0; }