芝士:樹上莫隊
樹上莫隊
將區間的莫隊演算法拓展到樹上,以此來解決一些樹上路徑的問題
思路
考慮為什麼普通莫隊為什麼只是排一個序就可以把暴力的時間複雜度除以\(\sqrt n\)?
其原因是儲存了之前的結果,也就是指對於重複的部分不需要多餘的空間,只需要調整詢問的順序就可以在整體上得到最大的優化。
樹上莫隊也是如此,但是我們要考慮如何將一顆樹轉換成序列的形式,並且所有的操作都需要變成對一個連續區間的詢問
樹上的路徑在一般情況下可以分成二種,一種是一個節點是另一個節點的\(LCA\),另一種情況就是這兩個節點的\(LCA\)是第三個點
兩種情況都面臨一個共同的困難,如何判斷一個點是不是在路徑上?尤拉序就派上了用場,如果一個點不在我們需要的路徑上,那麼我們對操作所轉換成為的連續區間中一定包含兩個這個點(先不考慮怎麼轉換的區間),但同時,我們所選擇的區間一定也要包含目標路徑,所有無用的節點也必須出現兩次
假設\(st[u]<st[v]\)
如果是第一種情況,很明顯的,選擇\(st[u]\sim st[v]\),
如果是第二種情況,選擇\(ed[u]\sim st[v]\),想象一下我們尤拉序是怎麼構造出來的,因為\(lca\)一定在\(u\)的子樹外,所以要選擇\(ed[u]\),但同時的,考慮到不能包含到\(v\)節點子樹內的節點,所以要選擇\(st[u]\),但是同時我們會把\(LCA\)忽略掉,不過我們也只會忽略掉\(LCA\),考慮兩條路徑\(u\sim LCA\)和\(LCA\sim v\),對於\(u\sim LCA\)這條路徑,因為我們是在往上走,所以對於所有的有用的節點,只會出現1次,如果是其他節點,一定會先進入再出來,所以會出現兩次,對於\(LCA\sim v\)
因為我們實際上把一個節點數為\(n\)的數變成了一個長度為\(2n\)的序列,同樣的,塊的大小要從\(\sqrt n\)變成\(\sqrt {2n}\)
所以時間複雜度依然是\(O(n\sqrt n)\)
例題
並沒有什麼特別的
#include<iostream> #include<algorithm> #include<vector> #include<cstdio> #include<cmath> #include<map> using namespace std; int divi; struct pr { int l,r; int lca,id; friend bool operator < (const pr &a,const pr &b) { return (a.l/divi==b.l/divi?a.r<b.r:a.l<b.l); } }p[100005]; int n,m,cnt,tot,cal; int dp[100005][25],dep[100005],st[100005],ed[100005]; int val[100005],ans[100005],id[200005]; vector<int> g[100005]; map<int,int> h; int l,r; int t[100005]; bool used[100005]; void dfs(int u,int fa) { st[u]=++cnt; id[cnt]=u; dep[u]=dep[fa]+1; for(int i=1;i<=20;i++) dp[u][i]=dp[dp[u][i-1]][i-1]; for(int i=0;i<g[u].size();i++) { int v=g[u][i]; if(v!=fa) { dp[v][0]=u; dfs(v,u); } } ed[u]=++cnt; id[cnt]=u; } int lca(int u,int v) { if(dep[u]>dep[v]) swap(u,v); for(int i=20;i>=0;i--) if(dep[u]<=dep[dp[v][i]]) v=dp[v][i]; if(u==v) return u; for(int i=20;i>=0;i--) if(dp[u][i]!=dp[v][i]) { u=dp[u][i]; v=dp[v][i]; } return dp[u][0]; } void add(int pos) { t[val[pos]]++; if(t[val[pos]]==1) cal++; } void del(int pos) { t[val[pos]]--; if(t[val[pos]]==0) cal--; } void calc(int pos) { used[pos]?del(pos):add(pos); used[pos]^=1; } int main() { ios::sync_with_stdio(false); cin>>n>>m; divi=sqrt(2*n); for(int i=1;i<=n;i++) { cin>>val[i]; if(h.count(val[i])==0) h[val[i]]=++tot; val[i]=h[val[i]]; } for(int i=1,u,v;i<n;i++) { cin>>u>>v; g[u].push_back(v); g[v].push_back(u); } dfs(1,0); for(int i=1,u,v;i<=m;i++) { cin>>u>>v; if(st[u]>st[v]) swap(u,v); int t=lca(u,v); if(u==t) { p[i].l=st[u]; p[i].r=st[v]; p[i].lca=0; } else { p[i].l=ed[u]; p[i].r=st[v]; p[i].lca=t; } p[i].id=i; } sort(p+1,p+m+1); l=r=cal=1; t[val[id[1]]]++; used[val[id[1]]]=1; for(int i=1;i<=m;i++) { while(r<p[i].r) { r++; calc(id[r]); } while(p[i].r<r) { calc(id[r]); r--; } while(l<p[i].l) { calc(id[l]); l++; } while(p[i].l<l) { l--; calc(id[l]); } if(p[i].lca) calc(p[i].lca); ans[p[i].id]=cal; if(p[i].lca) calc(p[i].lca); } for(int i=1;i<=m;i++) cout<<ans[i]<<'\n'; return 0; }