P2633 Count on a tree(主席樹)
阿新 • • 發佈:2020-08-05
只是轉化成樹上問題, 同樣是動態開點維護
#include<bits/stdc++.h> #define getsz(p) (p?p->sz:0) #define getlsz(p) (p?getsz(p->ls):0) #define getl(p) (p?p->ls:0) #define getr(p) (p?p->rs:0) using namespace std; typedef long long ll; const int N=4e5+10; int a[N]; int depth[N],f[N][25]; int h[N],ne[N],e[N],idx;View Codeint n; void add(int a,int b){ e[idx]=b,ne[idx]=h[a],h[a]=idx++; } int st[N]; struct node{ int l,r; int sz; node *ls,*rs; void update(){ sz = getsz(ls) + getsz(rs); } }*rt[N],pool[N*30]; vector<int> num; node * copynode(node *rt){ node *p=pool+(++idx); pool[idx]=*rt; return p; } node * newnode(int l,int r){ node *p=pool+(++idx); p->l=l,p->r=r; return p; } node *insert(node *rt,int l,int r,int x){ node *p; if(rt) p=copynode(rt); else p=newnode(l,r); p->sz++; int mid=l+r>>1; if(p->l==x&&p->r==x){return p; } if(x<=mid) p->ls=insert(p->ls,l,mid,x); else p->rs=insert(p->rs,mid+1,r,x); return p; } int find(int x){ return lower_bound(num.begin(),num.end(),x)-num.begin()+1; } void dfs(int u,int fa){ st[u]=1; rt[u]=insert(rt[fa],1,n,find(a[u])); int i; for(i=1;i<=20;i++){ if(depth[u]<=(1<<i)) break; f[u][i]=f[f[u][i-1]][i-1]; } for(i=h[u];i!=-1;i=ne[i]){ int j=e[i]; if(j==fa||st[j]) continue; depth[j]=depth[u]+1; f[j][0]=u; dfs(j,u); } } int lca(int a,int b){ if(depth[a]<depth[b]) swap(a,b); int i; for(i=20;i>=0;i--){ if(depth[f[a][i]]>=depth[b]){ a=f[a][i]; } } if(a==b) return a; for(i=20;i>=0;i--){ if(f[a][i]!=f[b][i]){ a=f[a][i]; b=f[b][i]; } } return f[a][0]; } int query(node* pL, node* pR, node* p0, node* p1, int k){ if(pR && pR->l==pR->r) return pR->l; if(pL && pL->l==pL->r) return pL->l; int k1 = getlsz(pL) + getlsz(pR) - getlsz(p0) - getlsz(p1); if(k1 >= k) return query(getl(pL), getl(pR), getl(p0), getl(p1), k); else return query(getr(pL), getr(pR), getr(p0), getr(p1), k - k1); } int main(){ ios::sync_with_stdio(false); int m; cin>>n>>m; int i; memset(h,-1,sizeof h); for(i=1;i<=n;i++){ cin>>a[i]; num.push_back(a[i]); } for(i=1;i<n;i++){ int u,v; cin>>u>>v; add(u,v); add(v,u); } sort(num.begin(),num.end()); num.erase(unique(num.begin(),num.end()),num.end()); n=(int)num.size(); int last=0; depth[1]=1; dfs(1,0); while(m--){ int u,v,k; cin>>u>>v>>k; u^=last; int p=lca(u,v); last=num[query(rt[u],rt[v],rt[p],rt[f[p][0]],k)-1]; cout<<last<<endl; } return 0; }