bzoj 3224 splay板子
阿新 • 發佈：2018-12-10
覺得模組化起來挺好看的，也不容易錯 ^_^
bzoj 3224 普通平衡樹
#include <cstdio>

// Keys of the two sentinel nodes. Assumes every real key lies strictly
// between them (true for BZOJ 3224 input), so every real key always has a
// predecessor node and a successor node in the tree.
const int inf = 2000000001;
// Node capacity: node ids 1..N-1 (two sentinels plus, presumably, at most
// 1e5 inserted nodes -- TODO confirm against the input bound).
const int N = 100005;

// Splay tree over int keys (BZOJ 3224, "ordinary balanced tree").
// Equal keys share one node whose multiplicity is c[]; sz[] counts keys
// including multiplicity. Node 0 is the null node: its sz/c stay 0, although
// rotations may write into ch[0]/fa[0]. init() creates the -inf sentinel
// (node 1) and the +inf sentinel (node 2).
struct bst {
    int root;      // current root node
    int pn;        // highest node id handed out so far
    int tp;        // top of the free-node stack stk[]
    int cs;        // dump counter used by outit()
    int stk[N];    // freed node ids, reused by newnode()
    int sz[N];     // subtree size, counting multiplicity
    int w[N];      // key stored in the node
    int c[N];      // multiplicity of the key
    int ch[N][2];  // children: [0] left, [1] right
    int fa[N];     // parent, 0 for the root

    // 1 if x is its parent's right child, 0 otherwise.
    int sd(int x) { return ch[fa[x]][1] == x; }

    // Recompute sz[x] from its children and its own multiplicity.
    inline void updata(int x) { sz[x] = sz[ch[x][0]] + sz[ch[x][1]] + c[x]; }

    // Allocate a detached leaf holding one copy of W, reusing a freed id if any.
    inline int newnode(int W) {
        int p;
        if (tp > 0) p = stk[tp--];
        else p = ++pn;
        sz[p] = 1;
        w[p] = W;
        c[p] = 1;
        ch[p][0] = ch[p][1] = 0;
        fa[p] = 0;
        return p;
    }

    // Reset to a tree holding only the sentinels: root 1 (-inf) whose right
    // child is 2 (+inf).
    void init() {
        tp = pn = cs = 0;
        newnode(-inf);
        newnode(inf);
        root = 1;
        sz[1] = 2;
        ch[1][1] = 2;
        sz[2] = 1;
        fa[2] = 1;
    }

    // Single rotation lifting x above its parent; fixes sizes of both nodes.
    void rotate(int x) {
        int f = fa[x], gf = fa[f], gs = sd(f), xs = sd(x);
        fa[f] = x;
        ch[gf][gs] = x;
        fa[x] = gf;
        ch[f][xs] = ch[x][xs ^ 1];
        fa[ch[x][xs ^ 1]] = f;
        ch[x][xs ^ 1] = f;
        updata(f);
        updata(x);
    }

    // Splay x upward until its parent is y (y == 0: make x the root).
    void splay(int x, int y) {
        if (x == y) return;
        while (fa[x] != y) {
            if (fa[fa[x]] == y) {
                rotate(x);                         // zig
            } else {
                if (sd(x) ^ sd(fa[x])) rotate(x);  // zig-zag
                else rotate(fa[x]);                // zig-zig
                rotate(x);
            }
        }
        if (!fa[x]) root = x;
    }

    // Insert one copy of W and splay its node to the root.
    void insert(int W) {
        int x = root, lst = 0, lsd = 0;
        while (1) {
            // Test for the null node BEFORE comparing keys: w[0] is 0, so
            // comparing first would treat the null node as an existing key 0
            // and corrupt c[0]/sz[0] instead of inserting.
            if (!x) {
                x = newnode(W);
                fa[x] = lst;
                ch[lst][lsd] = x;
                splay(x, 0);
                return;
            }
            if (w[x] == W) {
                c[x]++;
                sz[x]++;
                splay(x, 0);
                return;
            }
            lst = x;
            lsd = w[x] < W;
            x = ch[x][lsd];
        }
    }

    // Return the node holding the k-th smallest key (1-based, counting
    // multiplicity) within the subtree rooted at rt, splaying it until its
    // parent is rt's parent (it takes rt's place).
    // Precondition: 1 <= k <= sz[rt]; an out-of-range k is NOT detected.
    int find_kth(int rt, int k) {
        int x = rt;
        while (1) {
            int l = ch[x][0];
            if (sz[l] < k && sz[l] + c[x] >= k) {
                splay(x, fa[rt]);
                return x;
            }
            if (sz[l] >= k) {
                x = l;
            } else {
                k -= sz[l] + c[x];
                x = ch[x][1];
            }
        }
    }

    // Remove one copy of the key stored in live node x.
    void del(int x) {
        c[x]--;
        splay(x, 0);
        // x is the root now. splay() does nothing when x already was the root
        // (find_l leaves it there), so refresh sz[x] explicitly.
        updata(x);
        if (c[x] > 0) return;
        // Last copy: splay the minimum of the right subtree (present because
        // x is never the +inf sentinel) up to be x's right child; it has no
        // left child, so it adopts x's left subtree and becomes the root.
        int suc = find_kth(ch[x][1], 1);
        int l = ch[x][0];
        root = suc;
        ch[suc][0] = l;
        fa[l] = suc;
        fa[suc] = 0;
        stk[++tp] = x;  // recycle x's id
        updata(suc);
    }

    // Return (and splay to the root) the node whose key is W if present,
    // otherwise the node with the largest key < W; 0 if there is none.
    int find_l(int W) {
        int x = root, ll = 0;
        while (x) {
            if (w[x] == W) { ll = x; break; }
            if (w[x] < W) { ll = x; x = ch[x][1]; }
            else x = ch[x][0];
        }
        splay(ll, 0);
        return ll;
    }

    // Return (and splay to the root) the node whose key is W if present,
    // otherwise the node with the smallest key > W; 0 if there is none.
    int find_r(int W) {
        int x = root, rr = 0;
        while (x) {
            if (w[x] == W) { rr = x; break; }
            if (w[x] < W) x = ch[x][1];
            else { rr = x; x = ch[x][0]; }
        }
        splay(rr, 0);
        return rr;
    }

    // Rank of W: 1 + number of stored keys smaller than W. The -inf sentinel,
    // which always sits in the left subtree here, supplies the "+1".
    int find_rank(int W) {
        int x = find_l(W);
        if (w[x] == W) return sz[ch[x][0]];
        return c[x] + sz[ch[x][0]];
    }

    // Largest stored key strictly smaller than W.
    int find_pre(int W) {
        int x = find_l(W), y;
        if (w[x] == W) y = find_kth(ch[x][0], sz[ch[x][0]]);
        else y = x;
        return w[y];
    }

    // Smallest stored key strictly greater than W.
    int find_suc(int W) {
        int x = find_r(W), y;
        if (w[x] == W) y = find_kth(ch[x][1], 1);
        else y = x;
        return w[y];
    }

    // Debug: pre-order dump of the subtree rooted at x.
    void dfs(int x) {
        std::printf("id=%d w=%d sz=%d lc=%d rc=%d fa=%d\n", x, w[x], sz[x], ch[x][0], ch[x][1], fa[x]);
        if (ch[x][0]) dfs(ch[x][0]);
        if (ch[x][1]) dfs(ch[x][1]);
    }

    // Debug: dump the whole tree with a running case number.
    void outit() {
        cs++;
        std::printf("Case #%d:\n", cs);
        dfs(root);
        std::printf("\n");
    }
} tr;

// Read a possibly negative decimal integer from stdin into ret. Any '-' seen
// before the first digit makes the result negative. At end of input, ret is
// left as 0 instead of looping forever.
void read(int &ret) {
    int ch, sgn = 1;  // int, not char, so EOF is distinguishable
    ret = 0;
    for (ch = std::getchar(); ch != EOF && (ch < '0' || ch > '9'); ch = std::getchar())
        if (ch == '-') sgn = -1;
    for (; ch >= '0' && ch <= '9'; ch = std::getchar()) ret = ret * 10 + ch - '0';
    ret *= sgn;
}

// BZOJ 3224 driver: n operations "od x" with
// 1 insert, 2 delete one copy, 3 rank, 4 k-th smallest, 5 predecessor,
// 6 successor; ops 3..6 print their answer.
int main() {
    // freopen("bst.in","r",stdin);
    // freopen("bst.out","w",stdout);
    int n = 0, od = 0, x = 0, o = 0;
    tr.init();
    read(n);
    for (int i = 1; i <= n; i++) {
        read(od);
        read(x);
        if (od == 1) tr.insert(x);
        if (od == 2) {
            // find_l falls back to the predecessor when x is absent; only
            // delete when the key itself was found.
            int p = tr.find_l(x);
            if (tr.w[p] == x) tr.del(p);
        }
        if (od == 3) o = tr.find_rank(x);
        if (od == 4) o = tr.w[tr.find_kth(tr.root, x + 1)];  // +1 skips -inf
        if (od == 5) o = tr.find_pre(x);
        if (od == 6) o = tr.find_suc(x);
        if (od > 2) std::printf("%d\n", o);
        // tr.outit();  // uncomment to dump the tree after each operation
    }
    return 0;
}