SPOJ COT2 樹上的莫隊演算法,樹上區間查詢
阿新 • • 發佈:2019-02-14
題意:n個節點形成的一棵樹。每個節點有一個值。m次查詢,求出(u,v)路徑上出現了多少個不同的數。
樹上的莫隊演算法,同樣將樹分成siz=sqrt(n)塊,然後離線操作。先對樹dfs一遍,每當子樹節點個數num>=siz,就將這num個分成一塊。讀取所有的查詢按左端點所在塊排序。
用倍增法求lca單次要用logn複雜度,要跑3200ms。有個地方可以優化,就是知道了所有的查詢,也就是事先知道了轉移路徑,可以用離線的方法求O(n)求出所有需要用到的lca,這個寫起來比較麻煩,不過可以優化到1800ns。程式碼寫的比較挫。。。。
logn求lca:3200+ms
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <cctype> #include <string> #include <vector> #include <map> #include <set> #include <vector> #include <queue> #include <stack> #include <algorithm> using namespace std; const int maxn=4e4+10; const int maxm=1e5+10; int n,m, siz; vector<int> g[maxn]; int a[maxn], b[maxn], ans[maxm]; int tot[maxn], in[maxn]; int fa[maxn][20], dep[maxn]; struct Query { int l, r, id; int st,ed; bool operator <(const Query& a) const { return st!=a.st? st<a.st: ed<a.ed; //先按左端點所在塊先後排序,其次考慮又右端點所在塊 } }; Query q[maxm]; int tag, bel[maxn]; int st[maxn], top; int dfs(int u, int par, int d, int &cnt) { dep[u]=d; fa[u][0]=par; int num=0; for(int i=0; i<g[u].size(); i++){ int v=g[u][i]; if(v!=par){ num+=dfs(v, u, d+1, cnt); if(num>=siz){ //子樹大小>=sqrt(n),分成一塊 for(int i=0; i<num; i++) bel[st[--top]]=tag; tag++; num=0; } } } st[top++]=u;//記錄子樹遍歷的點 return num+1; } void init() { for(int i=0; i<=n; i++) g[i].clear(); memset(tot, 0, sizeof(tot)); memset(in, 0, sizeof(in)); siz=sqrt(n); for(int i=1;i<=n; i++) scanf("%d",&a[i]), b[i]=a[i]; sort(b+1, b+n+1); for(int i=1; i<=n; i++) a[i]=lower_bound(b+1, b+n+1, a[i])-b; for(int i=0; i<n-1; i++){ int u,v; scanf("%d%d", &u, &v); g[u].push_back(v); g[v].push_back(u); } int cnt=0; tag=top=0; int num=dfs(1, -1, 0, cnt); for(int i=0; i<num; i++) bel[st[--top]]=tag; //最後剩下的數也分成一塊 for(int i=1; i<20; i++){ for(int u=1; u<=n; u++) if(fa[u][i-1]==-1) fa[u][i]=-1; else fa[u][i]=fa[fa[u][i-1]][i-1]; } for(int i=0; i<m; i++){ scanf("%d%d", &q[i].l, &q[i].r); if(bel[q[i].l]>bel[q[i].r]) swap(q[i].l, q[i].r); q[i].id=i; q[i].st=bel[q[i].l]; q[i].ed=bel[q[i].r]; } sort(q, q+m); } int lca(int u, int v) { if(dep[u]>dep[v]) swap(u, v); for(int i=0; i<20; i++) if((dep[v]-dep[u])>>i&1) v=fa[v][i]; if(u==v) return u; for(int i=19; i>=0; i--){ if(fa[u][i]!=fa[v][i]){ u=fa[u][i]; v=fa[v][i]; } } return fa[u][0]; } void solve() { int res=0; int cu=1, cv=1; for(int i=0; i<m; i++){ int nu=q[i].l, nv=q[i].r; int par=lca(cu, nu); while(cu!=par){ if(in[cu]){ if(--tot[a[cu]]==0) res--; } else if(++tot[a[cu]]==1) res++; in[cu]^=1; cu=fa[cu][0]; } cu=nu; while(cu!=par){ if(in[cu]){ if(--tot[a[cu]]==0) res--; } else if(++tot[a[cu]]==1) res++; in[cu]^=1; cu=fa[cu][0]; } cu=nu; par=lca(cv, nv); while(cv!=par){ if(in[cv]){ if(--tot[a[cv]]==0) res--; } else if(++tot[a[cv]]==1) res++; in[cv]^=1; cv=fa[cv][0]; } cv=nv; while(cv!=par){ if(in[cv]){ if(--tot[a[cv]]==0) res--; } else if(++tot[a[cv]]==1) res++; in[cv]^=1; cv=fa[cv][0]; } cv=nv; par=lca(cu, cv); ans[q[i].id]=res+(!tot[a[par]]); } } int main() { while(scanf("%d%d", &n, &m)==2){ init(); solve(); for(int i=0; i<m; i++) printf("%d\n", ans[i]); } return 0; }
離線查詢lca:1800+ms
#include <iostream> #include <cstdio> #include <cstring> #include <cmath> #include <cctype> #include <string> #include <vector> #include <map> #include <set> #include <vector> #include <queue> #include <stack> #include <algorithm> using namespace std; #pragma comment(linker, "/STACK:1024000000,1024000000") typedef pair<int,int> P; #define fir first #define sec second const int maxn=4e4+10; const int maxm=1e5+10; int n,m, siz; vector<int> g[maxn]; int first[maxn],ltot=0, nxt[6*maxm]; P lq[6*maxm];//所有需要查詢的lca,lq[i].first儲存v,second儲存查詢的id int a[maxn], b[maxn], ans[maxm]; int tot[maxn], in[maxn], fa1[maxn]; int fa[maxn], lca[3*maxm], col[maxn]; int bel[maxn],st[maxn],top=0; struct Query { int l, r, id; int st,ed; bool operator <(const Query& a) const { return st!=a.st? st<a.st: ed<a.ed; } }; Query q[maxm]; int tag; int dfs(int u, int par, int &cnt)//分塊 { fa1[u]=par; int num=0; for(int i=0; i<g[u].size(); i++){ int v=g[u][i]; if(v!=par) num+=dfs(v, u, cnt); if(num>=siz){ for(int i=0; i<num; i++) bel[st[--top]]=tag; tag++; num=0; } } st[top++]=u; return num+1; } int find(int u) { return fa[u]==u?u:(fa[u]=find(fa[u])); } int unite(int x, int y) { x=fa[x]; y=fa[y]; fa[y]=x; } void dfs2(int u, int par)//離線查詢所有lca { col[u]=1; for(int i=first[u]; i!=-1; i=nxt[i]){ int v=lq[i].fir, id=lq[i].sec; if(!col[v]) continue; else if(col[v]==1){ lca[id]=v; } else{ lca[id]=find(v); } } for(int i=0; i<g[u].size(); i++){ int v=g[u][i]; if(v!=par) dfs2(v, u); } col[u]=2; unite(par, u); } void add(int u, int v, int id)//查詢m<=1e5,數比較多所以用前向星實現優化 { lq[ltot]=P(v,id); nxt[ltot]=first[u]; first[u]=ltot++; } void init() { for(int i=0; i<=n; i++) g[i].clear(); memset(tot, 0, sizeof(tot)); memset(in, 0, sizeof(in)); siz=sqrt(n); for(int i=1;i<=n; i++) scanf("%d", a+i), b[i]=a[i]; sort(b+1, b+n+1); for(int i=1; i<=n; i++) a[i]=lower_bound(b+1, b+n+1, a[i])-b; for(int i=0; i<n-1; i++){ int u,v; scanf("%d%d", &u, &v); g[u].push_back(v); g[v].push_back(u); } int cnt=0; top=0; tag=0; int num=dfs(1, -1, cnt); for(int i=0; i<num; i++) bel[st[--top]]=tag; for(int i=0; i<m; i++){ scanf("%d%d", &q[i].l, &q[i].r); if(bel[q[i].l]>bel[q[i].r]) swap(q[i].l, q[i].r); q[i].id=i; q[i].st=bel[q[i].l]; q[i].ed=bel[q[i].r]; } sort(q, q+m); cnt=0; ltot=0; memset(first, -1, sizeof(first)); add(1, q[0].l, cnt); add(q[0].l, 1, cnt++); add(1, q[0].r, cnt); add(q[0].r, 1, cnt++); add(q[0].r, q[0].l, cnt); add(q[0].l, q[0].r, cnt++); //add(q[0].r, q[0].l, cnt++); for(int i=0; i<m-1; i++){ add(q[i].l, q[i+1].l, cnt);//第i個查詢左端點向第i+1個左端點轉移,所以需要它們之間的lca add(q[i+1].l, q[i].l, cnt++); add(q[i].r, q[i+1].r, cnt);//第i個查詢右端點向第i+1個右端點轉移 add(q[i+1].r, q[i].r, cnt++); add(q[i+1].r, q[i+1].l, cnt);//左端點和右端點的lca add(q[i+1].l, q[i+1].r,cnt++); } for(int i=0; i<=n; i++) fa[i]=i; memset(col, 0, sizeof(col)); dfs2(1, 0); } void solve() { int res=0; int cu=1, cv=1; for(int i=0; i<m; i++){ int nu=q[i].l, nv=q[i].r; //cout<<lca[i*3]<<' '<<lca[i*3+1]<<' '<<lca[i*3+2]<<endl; int par=lca[i*3]; while(cu!=par){ if(in[cu]){ if(--tot[a[cu]]==0) res--; } else if(++tot[a[cu]]==1) res++; in[cu]^=1; cu=fa1[cu]; } cu=nu; while(cu!=par){ if(in[cu]){ if(--tot[a[cu]]==0) res--; } else if(++tot[a[cu]]==1) res++; in[cu]^=1; cu=fa1[cu]; } cu=nu; par=lca[i*3+1]; while(cv!=par){ if(in[cv]){ if(--tot[a[cv]]==0) res--; } else if(++tot[a[cv]]==1) res++; in[cv]^=1; cv=fa1[cv]; } cv=nv; while(cv!=par){ if(in[cv]){ if(--tot[a[cv]]==0) res--; } else if(++tot[a[cv]]==1) res++; in[cv]^=1; cv=fa1[cv]; } cv=nv; par=lca[i*3+2]; ans[q[i].id]=res+(!tot[a[par]]); } } int main() { while(cin>>n>>m){ init(); solve(); for(int i=0; i<m; i++) printf("%d\n", ans[i]); } return 0; }