AcWing 253. 普通平衡樹
阿新 • • 發佈:2022-05-25
平衡樹
treap模板
#include<bits/stdc++.h> using namespace std; const int N = 1e5+10,INF = 0x3f3f3f3f; int idx; struct NODE{ int l,r; int key,val; int cnt,size; }node[N]; int newnode(int key){ node[++idx].key=key; node[idx].val=rand(); node[idx].size=node[idx].cnt=1; return idx; } void push_up(int p){ node[p].size=node[node[p].l].size+node[node[p].r].size+node[p].cnt; } void zig(int &p){ int q=node[p].l; node[p].l=node[q].r; node[q].r=p; p=q;// push_up(node[p].r),push_up(p); } void zag(int &p){ int q=node[p].r; node[p].r=node[q].l; node[q].l=p; p=q;// push_up(node[p].l),push_up(p); } int init(){ newnode(-INF),newnode(INF); int root=1; node[1].r=2; push_up(root); if(node[1].val<node[2].val) zag(root); return root; } void my_insert(int &p,int x){ if(!p) p=newnode(x); else if(node[p].key==x) node[p].cnt++; else if(node[p].key>x){ my_insert(node[p].l,x); if(node[p].val<node[node[p].l].val) zig(p); } else{ my_insert(node[p].r,x); if(node[p].val<node[node[p].r].val) zag(p); } push_up(p); } void my_remove(int &p,int x){ if(!p) return; else if(node[p].key==x){ if(node[p].cnt>1) node[p].cnt--; else if(node[p].l || node[p].r){// if(!node[p].r || node[node[p].l].val>node[node[p].r].val){ zig(p); my_remove(node[p].r,x); } else{ zag(p); my_remove(node[p].l,x); } } else p=0; } else{ if(node[p].key>x) my_remove(node[p].l,x); else my_remove(node[p].r,x); } push_up(p); } int get_rank_by_key(int p,int x){ if(!p) return 0; else if(node[p].key==x) return node[node[p].l].size+1; else if(node[p].key>x) return get_rank_by_key(node[p].l,x); else return get_rank_by_key(node[p].r,x)+node[node[p].l].size+node[p].cnt; } int get_key_by_rank(int p,int x){ if(!p) return INF+1; else if(node[node[p].l].size>=x) return get_key_by_rank(node[p].l,x); else if(node[node[p].l].size+node[p].cnt>=x) return node[p].key; else return get_key_by_rank(node[p].r,x-node[node[p].l].size-node[p].cnt); } int get_prev(int p,int x){ if(!p) return -INF; else if(node[p].key>=x) return get_prev(node[p].l,x); else return max(node[p].key,get_prev(node[p].r,x)); } int get_next(int p,int x){ if(!p) return INF; else if(node[p].key<=x) return get_next(node[p].r,x); else return min(node[p].key,get_next(node[p].l,x)); } int main(){ int root=init(); int n,x,ops; cin>>n; while(n--){ cin>>ops>>x; if(ops==1) my_insert(root,x); else if(ops==2) my_remove(root,x); else if(ops==3) cout<<get_rank_by_key(root,x)-1<<endl; else if(ops==4) cout<<get_key_by_rank(root,x+1)<<endl; else if(ops==5) cout<<get_prev(root,x)<<endl; else cout<<get_next(root,x)<<endl; } return 0; }