關於那些平衡樹的板子
阿新 • 發佈:2020-08-12
Treap
// Treap (BST ordered by key, max-heap ordered by a random priority) used to
// answer the six classic "ordinary balanced tree" operations on a multiset
// of ints.
//
// Input:  n m, then n initial keys, then m lines "op x":
//   1 x : insert one copy of x
//   2 x : delete one copy of x (no-op if absent)
//   3 x : print the rank of x (number of stored keys < x, plus 1)
//   4 x : print the key with rank x
//   5 x : print the largest stored key strictly less than x
//   6 x : print the smallest stored key strictly greater than x
//         (any op code other than 1..5 is treated as 6)
// Output: one line per query (ops 3-6), then "ans = " followed by the XOR of
// all printed query answers.
#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <vector>

// Reads one decimal integer from stdin, skipping any non-digit characters
// before it; a '-' among the skipped characters makes the result negative.
// Returns 0 if EOF is reached before any digit (instead of spinning forever).
inline int read() {
    int x = 0, w = 1;
    int ch = std::getchar();  // int, not char, so EOF is distinguishable
    for (; ch < '0' || ch > '9'; ch = std::getchar()) {
        if (ch == EOF) return 0;
        if (ch == '-') w = -1;
    }
    for (; ch >= '0' && ch <= '9'; ch = std::getchar()) x = x * 10 + ch - '0';
    return x * w;
}

// Sentinel keys +inf / -inf. Stored keys must lie strictly between them.
const int inf = 0x7fffffff;

struct treap {
    int l, r;     // pool indices of the left/right child (0 = no child)
    int val;      // key; the tree is a BST on this field
    int data;     // random priority; parents have priority >= children
    int cnt;      // number of copies of val held by this node
    int size;     // total number of copies in this subtree
};

// Node pool. Index 0 is the shared null node, whose fields stay zero.
// main() sizes it once before build() and it is never resized afterwards:
// insert()/remove()/zig()/zag() hold int& references into its elements,
// which a reallocation would invalidate.
std::vector<treap> a;
int tot, root;

// Recomputes the subtree size of node p from its children.
inline void update(int p) {
    a[p].size = a[a[p].l].size + a[a[p].r].size + a[p].cnt;
}

// Right rotation: the left child of p becomes the subtree root (p is rebound).
inline void zig(int &p) {
    int q = a[p].l;
    a[p].l = a[q].r;
    a[q].r = p;
    p = q;
    update(a[p].r);
    update(p);
}

// Left rotation: the right child of p becomes the subtree root (p is rebound).
inline void zag(int &p) {
    int q = a[p].r;
    a[p].r = a[q].l;
    a[q].l = p;
    p = q;
    update(a[p].l);
    update(p);
}

// Allocates a fresh leaf holding one copy of v with a random priority and
// returns its pool index. Nodes are never reused, so the pool must have room
// for every node ever created.
inline int New(int v) {
    ++tot;
    a[tot].l = a[tot].r = 0;
    a[tot].val = v;
    a[tot].data = std::rand();
    a[tot].cnt = a[tot].size = 1;
    return tot;
}

// Creates the two sentinels: node 1 = -inf, node 2 = +inf (getpre/getnxt
// rely on these exact indices). Rotates once if needed so the heap order
// also holds between them.
inline void build() {
    New(-inf);
    New(inf);
    root = 1;
    a[1].r = 2;
    update(root);
    if (a[1].data < a[2].data) zag(root);
}

// Returns 1 + (number of stored copies in subtree p whose key is < v), i.e.
// the rank v has, or would have if inserted. The -inf sentinel is counted.
int getrank(int p, int v) {
    if (p == 0) return 1;  // empty subtree: nothing smaller, so rank 1
    if (v == a[p].val) return a[a[p].l].size + 1;
    if (v < a[p].val) return getrank(a[p].l, v);
    return getrank(a[p].r, v) + a[a[p].l].size + a[p].cnt;
}

// Returns the key that holds 1-based rank k within subtree p (sentinels
// included), or inf when no such position exists.
int getval(int p, int k) {
    if (p == 0) return inf;
    int leftSize = a[a[p].l].size;
    if (k <= leftSize) return getval(a[p].l, k);
    if (k <= leftSize + a[p].cnt) return a[p].val;
    return getval(a[p].r, k - leftSize - a[p].cnt);
}

// Inserts one copy of v into subtree p, restoring heap order by rotating the
// new node upward while its priority exceeds its parent's.
void insert(int &p, int v) {
    if (p == 0) {
        p = New(v);
        return;
    }
    if (v == a[p].val) {
        a[p].cnt++;
        update(p);
        return;
    }
    if (v < a[p].val) {
        insert(a[p].l, v);
        if (a[p].data < a[a[p].l].data) zig(p);
    } else {
        insert(a[p].r, v);
        if (a[p].data < a[a[p].r].data) zag(p);
    }
    update(p);
}

// Returns the largest stored key strictly less than v, or -inf if none.
int getpre(int v) {
    int ans = 1;  // node 1 is the -inf sentinel
    int p = root;
    while (p) {
        if (v == a[p].val) {
            // Exact match: the predecessor, if deeper, is the maximum of
            // the left subtree.
            if (a[p].l > 0) {
                p = a[p].l;
                while (a[p].r > 0) p = a[p].r;
                ans = p;
            }
            break;
        }
        if (a[p].val < v && a[p].val > a[ans].val) ans = p;
        p = v < a[p].val ? a[p].l : a[p].r;
    }
    return a[ans].val;
}

// Returns the smallest stored key strictly greater than v, or +inf if none.
int getnxt(int v) {
    int ans = 2;  // node 2 is the +inf sentinel
    int p = root;
    while (p) {
        if (v == a[p].val) {
            // Exact match: the successor, if deeper, is the minimum of the
            // right subtree.
            if (a[p].r > 0) {
                p = a[p].r;
                while (a[p].l > 0) p = a[p].l;
                ans = p;
            }
            break;
        }
        if (a[p].val > v && a[p].val < a[ans].val) ans = p;
        p = v < a[p].val ? a[p].l : a[p].r;
    }
    return a[ans].val;
}

// Removes one copy of v from subtree p; does nothing if v is absent.
// A node whose last copy is removed is rotated down (higher-priority child
// up) until it becomes a leaf, then unlinked.
void remove(int &p, int v) {
    if (p == 0) return;
    if (v == a[p].val) {
        if (a[p].cnt > 1) {
            a[p].cnt--;
            update(p);
            return;
        }
        if (a[p].l == 0 && a[p].r == 0) {
            p = 0;  // leaf: just unlink it
            return;
        }
        if (a[p].r == 0 || (a[p].l != 0 && a[a[p].l].data > a[a[p].r].data)) {
            zig(p);              // target moved to the right child
            remove(a[p].r, v);
        } else {
            zag(p);              // target moved to the left child
            remove(a[p].l, v);
        }
        update(p);
        return;  // done: do not fall through into the search below
    }
    if (v < a[p].val) remove(a[p].l, v);
    else remove(a[p].r, v);
    update(p);
}

int main() {
    std::srand(static_cast<unsigned>(std::time(NULL)));
    int n = read(), m = read();
    // Pool: null node 0, two sentinels, n initial keys, and at most one new
    // node per operation. Sized before build() and never resized after.
    a.assign(std::max(n, 0) + std::max(m, 0) + 3, treap());
    build();
    for (int i = 0; i < n; i++) insert(root, read());
    int ans = 0;  // XOR of every printed query answer
    for (int i = 0; i < m; i++) {
        int op = read(), x = read();
        if (op == 1) {
            insert(root, x);
            continue;
        }
        if (op == 2) {
            remove(root, x);
            continue;
        }
        int tmp;
        switch (op) {
            case 3: tmp = getrank(root, x) - 1; break;  // discount -inf sentinel
            case 4: tmp = getval(root, x + 1); break;   // skip -inf sentinel
            case 5: tmp = getpre(x); break;
            default: tmp = getnxt(x); break;            // op 6 (and unknown ops)
        }
        std::printf("%d\n", tmp);
        ans ^= tmp;
    }
    std::printf("ans = %d\n", ans);
    return 0;
}