P5327-[ZJOI2019]語言【線段樹合併,LCA】
阿新 • • 發佈:2021-11-30
正題
題目連結:https://www.luogu.com.cn/problem/P5327
題目大意
給出\(n\)個點的一棵樹,和\(m\)條路徑,求有多少個點對至少存在一條路徑經過它們。
\(1\leq n,m\leq 10^5\)
解題思路
有一個很顯然的性質,如果點\(z\)在\(x\rightarrow y\)的路徑上,並且\((x,z)\)不合法,那麼\((x,y)\)肯定不合法。
所以這樣的對於一個節點\(x\)來說所有和它合法的點會形成一棵生成樹,這棵生成樹是哪來的也很好說,我們把所有經過\(x\)的路徑\(s\rightarrow t\)的\(s\)和\(t\)存下來,構出一棵虛樹,這棵虛樹的大小就是對於這個點來說合法的點個數。
虛樹的大小怎麼求,這個方法很多,而我們儘量使用一種比較方便動態維護的方法,把所有點按照\(dfs\)序排序,假設節點\(x\)之後排的是\(y\)(定義第一個之前排的是最後一個),那麼就是所有\(dep_{y}-dep_{lca(x,y)}\)的和。
不難發現這個東西可以在序列上維護,同理的可以線上段樹上維護,每個區間記錄\(dfs\)最小的和最大的點就好了。
那麼做法已經很顯然了,一條路徑經過的點我們可以用樹上差分來做到也就是\(s\leftrightarrow lca(s,t)\)和\(t\leftrightarrow lca(s,t)\)的部分。
然後兩個子樹的資訊合併的時候用線段樹合併就好了。
寫了個\(RMQ\)來快速求\(LCA\)。
時間複雜度:\(O(n\log n)\)
code
#include<cstdio> #include<cstring> #include<algorithm> #include<vector> using namespace std; const int N=1e5+10,M=N*20*4; struct node{ int to,next; }a[N<<1]; int n,m,tot,cnt,dfc,ls[N],dep[N],rt[N]; int dfn[N],rfn[N],rgn[N],lg[N<<1],f[N<<1][19]; vector<int> v[N],g[N]; long long ans; void addl(int x,int y){ a[++tot].to=y; a[tot].next=ls[x]; ls[x]=tot;return; } void dfs(int x,int fa){ dep[x]=dep[fa]+1; dfn[++dfc]=x;rfn[x]=dfc; rgn[x]=++cnt;f[cnt][0]=x; for(int i=ls[x];i;i=a[i].next){ int y=a[i].to; if(y==fa)continue; dfs(y,x);f[++cnt][0]=x; } return; } int LCA(int x,int y){ int l=rgn[x],r=rgn[y]; if(l>r)swap(l,r); int z=lg[r-l+1]; x=f[l][z];y=f[r-(1<<z)+1][z]; return (dep[x]<dep[y])?x:y; } int calc(int x,int y){ if(!x||!y)return 0; return dep[y]-dep[LCA(x,y)]; } struct SegTree{ int cnt,w[M],s[M],lp[M],rp[M],ls[M],rs[M]; void Merge(int x,int ls,int rs){ s[x]=s[ls]+s[rs]+calc(rp[ls],lp[rs]); lp[x]=lp[ls]?lp[ls]:lp[rs]; rp[x]=rp[rs]?rp[rs]:rp[ls]; return; } void Change(int &x,int L,int R,int pos,int val){ if(!x)x=++cnt; if(L==R){ w[x]+=val; if(w[x])lp[x]=rp[x]=dfn[pos]; else lp[x]=rp[x]=0; return; } int mid=(L+R)>>1; if(pos<=mid)Change(ls[x],L,mid,pos,val); else Change(rs[x],mid+1,R,pos,val); Merge(x,ls[x],rs[x]);return; } int Merge(int x,int y,int L,int R){ if(!x||!y)return x+y; if(L==R){ w[x]=w[x]+w[y]; if(w[x])lp[x]=rp[x]=dfn[L]; else lp[x]=rp[x]=0; return x; } int mid=(L+R)>>1; ls[x]=Merge(ls[x],ls[y],L,mid); rs[x]=Merge(rs[x],rs[y],mid+1,R); Merge(x,ls[x],rs[x]);return x; } }T; void solve(int x,int fa){ for(int i=ls[x];i;i=a[i].next){ int y=a[i].to; if(y==fa)continue;solve(y,x); rt[x]=T.Merge(rt[x],rt[y],1,n); } for(int i=0;i<v[x].size();i++) T.Change(rt[x],1,n,v[x][i],1); ans+=T.s[rt[x]]+dep[T.lp[rt[x]]]-dep[LCA(T.lp[rt[x]],T.rp[rt[x]])]+1; for(int i=0;i<g[x].size();i++) T.Change(rt[x],1,n,g[x][i],-2); return; } int main() { scanf("%d%d",&n,&m); for(int i=1,x,y;i<n;i++){ scanf("%d%d",&x,&y); addl(x,y);addl(y,x); } dfs(1,0); for(int j=1;(1<<j)<=cnt;j++) for(int i=1;i+(1<<j)-1<=cnt;i++){ int x=f[i][j-1],y=f[i+(1<<j-1)][j-1]; f[i][j]=(dep[x]<dep[y])?x:y; } for(int i=2;i<=cnt;i++)lg[i]=lg[i>>1]+1; for(int i=1,x,y;i<=m;i++){ scanf("%d%d",&x,&y); v[x].push_back(rfn[x]); v[x].push_back(rfn[y]); v[y].push_back(rfn[x]); v[y].push_back(rfn[y]); int lca=LCA(x,y); g[lca].push_back(rfn[x]); g[lca].push_back(rfn[y]); } solve(1,0); printf("%lld\n",(ans-n)/2ll); return 0; }