又來主席樹第一發:dfs序-代代宗師
阿新 • 發佈:2020-12-30
// Subtree/depth counting with an Euler tour and a persistent segment tree.
//
// Input:  n q, then n integers (a per-vertex flag; nonzero means "counted"),
//         then n-1 undirected edges "u v", then q queries "v h".
// Output: for each query, one line with the number of vertices u such that
//         u lies strictly below v in the tree rooted at vertex 1 (v itself is
//         not counted), u's flag is nonzero, and depth(u) == h (root depth 1).
//         Depths outside the tree's depth range yield 0.
// NOTE(review): excluding v itself mirrors the original program; confirm it
//         matches the problem statement.
#include <cstddef>
#include <iostream>
#include <string.h>
#include <vector>

namespace {

// Euler tour where every vertex is recorded twice: once on entry (with its
// depth) and once on exit (depth marker -1). Positions are numbered from 1;
// index 0 of `vertex`/`depth` is a dummy slot.
struct EulerTour {
    std::vector<int> vertex;  // vertex[pos]: vertex recorded at tour position pos
    std::vector<int> depth;   // depth[pos]: depth on entry, -1 on exit
    std::vector<int> enter;   // enter[v]: entry position of v (0 if never reached)
    std::vector<int> leave;   // leave[v]: exit position of v (0 if never reached)
    int maxDepth = 0;         // deepest entry depth seen
};

// Builds the Euler tour of the component containing `root`.
// An explicit stack replaces recursion so a long path cannot overflow the call
// stack; children are visited in adjacency-list order, exactly as a recursive
// DFS would. Precondition: 1 <= root < adj.size().
EulerTour buildEulerTour(const std::vector<std::vector<int>>& adj, int root) {
    const std::size_t slots = adj.size();
    EulerTour tour;
    tour.vertex.reserve(2 * slots + 1);
    tour.depth.reserve(2 * slots + 1);
    tour.vertex.push_back(0);  // dummy position 0
    tour.depth.push_back(0);
    tour.enter.assign(slots, 0);
    tour.leave.assign(slots, 0);

    struct Frame {
        int v;
        int depth;
        std::size_t next;  // index of the next neighbour to inspect
    };
    std::vector<char> visited(slots, 0);
    std::vector<Frame> stack;

    const auto enterVertex = [&](int v, int d) {
        visited[v] = 1;
        tour.vertex.push_back(v);
        tour.depth.push_back(d);
        tour.enter[v] = static_cast<int>(tour.vertex.size()) - 1;
        if (d > tour.maxDepth) tour.maxDepth = d;
        stack.push_back({v, d, 0});
    };

    enterVertex(root, 1);
    while (!stack.empty()) {
        Frame& top = stack.back();
        if (top.next < adj[top.v].size()) {
            // Copy what we need before enterVertex may reallocate `stack`.
            const int child = adj[top.v][top.next++];
            const int childDepth = top.depth + 1;
            if (!visited[child]) enterVertex(child, childDepth);
        } else {
            const int v = top.v;
            stack.pop_back();
            tour.vertex.push_back(v);
            tour.depth.push_back(-1);
            tour.leave[v] = static_cast<int>(tour.vertex.size()) - 1;
        }
    }
    return tour;
}

// Persistent segment tree over the integer range [lo, hi] storing a count per
// value. Nodes live in one pool and reference children by index; node 0 is a
// shared all-zero node (its children are itself), so version kEmpty needs no
// explicit build and unchanged subtrees are shared between versions.
class PersistentDepthCounter {
public:
    static constexpr int kEmpty = 0;

    PersistentDepthCounter(int lo, int hi) : lo_(lo), hi_(hi) {
        nodes_.push_back({0, 0, 0});
    }

    // Returns the root of a new version equal to `version` plus one at x.
    // Precondition: lo <= x <= hi.
    int add(int version, int x) { return addImpl(version, lo_, hi_, x); }

    // Returns the count stored at x in `version`; 0 when x is outside [lo, hi].
    int count(int version, int x) const {
        if (x < lo_ || x > hi_) return 0;
        int node = version;
        int lo = lo_;
        int hi = hi_;
        while (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            if (x <= mid) {
                node = nodes_[node].left;
                hi = mid;
            } else {
                node = nodes_[node].right;
                lo = mid + 1;
            }
        }
        return nodes_[node].sum;
    }

private:
    struct Node {
        int left;
        int right;
        int sum;
    };

    // Path-copies the root-to-leaf path for x, incrementing each copied node.
    int addImpl(int prev, int lo, int hi, int x) {
        const Node copy = nodes_[prev];  // copy first: push_back may reallocate
        nodes_.push_back(copy);
        const int cur = static_cast<int>(nodes_.size()) - 1;
        nodes_[cur].sum += 1;
        if (lo < hi) {
            const int mid = lo + (hi - lo) / 2;
            if (x <= mid) {
                const int child = addImpl(copy.left, lo, mid, x);
                nodes_[cur].left = child;
            } else {
                const int child = addImpl(copy.right, mid + 1, hi, x);
                nodes_[cur].right = child;
            }
        }
        return cur;
    }

    int lo_;
    int hi_;
    std::vector<Node> nodes_;
};

// Reads the whole problem from `in` and writes one answer per query to `out`.
void solve(std::istream& in, std::ostream& out) {
    int n = 0;
    int q = 0;
    if (!(in >> n >> q)) return;
    if (n < 0) n = 0;

    std::vector<int> flag(n + 1, 0);
    for (int v = 1; v <= n; ++v) in >> flag[v];

    std::vector<std::vector<int>> adj(n + 1);
    for (int e = 1; e < n; ++e) {
        int a = 0;
        int b = 0;
        in >> a >> b;
        // Skip malformed edges rather than indexing out of bounds.
        if (a < 1 || a > n || b < 1 || b > n) continue;
        adj[a].push_back(b);
        adj[b].push_back(a);
    }

    EulerTour tour;
    if (n >= 1) tour = buildEulerTour(adj, 1);

    // roots[pos] is the version holding every flagged vertex entered at tour
    // positions 1..pos, counted per depth.
    const int maxDepth = tour.maxDepth > 0 ? tour.maxDepth : 1;
    PersistentDepthCounter counter(1, maxDepth);
    std::vector<int> roots(1, PersistentDepthCounter::kEmpty);
    roots.reserve(tour.vertex.size() + 1);
    for (std::size_t pos = 1; pos < tour.vertex.size(); ++pos) {
        const int d = tour.depth[pos];
        if (d != -1 && flag[tour.vertex[pos]] != 0) {
            roots.push_back(counter.add(roots.back(), d));
        } else {
            roots.push_back(roots.back());
        }
    }

    for (int i = 0; i < q; ++i) {
        int v = 0;
        int h = 0;
        if (!(in >> v >> h)) break;
        int answer = 0;
        if (v >= 1 && v <= n) {
            // Positions (enter[v], leave[v]] are exactly v's proper descendants'
            // entries (plus exits, which were never counted).
            answer = counter.count(roots[tour.leave[v]], h) -
                     counter.count(roots[tour.enter[v]], h);
        }
        out << answer << '\n';
    }
}

}  // namespace

int main() {
    std::ios::sync_with_stdio(false);
    std::cin.tie(nullptr);
    solve(std::cin, std::cout);
    return 0;
}