程式人生 > bzoj 3224 splay板子

bzoj 3224 splay板子

覺得寫成模組化的樣子挺好看的，也不容易出錯 ^_^

bzoj 3224 普通平衡樹

#include <cstdio>
// Shorthand for the left/right child of node `x`. These expand to ch[x][0] and
// ch[x][1] using whatever `x` is visible at the point of use (a local, a
// parameter, or the global `x` declared below).
#define lf ch[x][0]
#define rg ch[x][1]
// Inclusive counting loop: i runs from j to k.
#define rep(i,j,k) for (i=j;i<=k;i++)

// Magnitude of the -inf/+inf sentinel keys (same value as 2e9+1).
const int inf = 2000000001;
// Size of the node pool (same value as 1e5+5).
const int N = 100005;

using namespace std;

// State shared with main(): operation count, loop index, opcode, operand, answer.
int n, i, od, x, o;
struct bst{
	int root,pn,tp,cs,stk[N],sz[N],w[N],c[N],ch[N][2],fa[N];
	int sd(int x) {return ch[fa[x]][1]==x;}
	inline void updata(int x) { sz[x]=sz[lf]+sz[rg]+c[x]; }
	inline int newnode(int W) {
		int p;
		if (tp>0) p=stk[tp--];
		else p=++pn;
		sz[p]=1; w[p]=W; c[p]=1; ch[p][0]=ch[p][1]=0; fa[p]=0;
		return p;
	}
	void init() {
		tp=pn=cs=0; newnode(-inf); newnode(inf); root=1;
		sz[1]=2; ch[1][1]=2; sz[2]=1; fa[2]=1;
	}
	void rotate(int x)
	{
		int f=fa[x],gf=fa[f],gs=sd(f),xs=sd(x);
		fa[f]=x; ch[gf][gs]=x; fa[x]=gf; ch[f][xs]=ch[x][xs^1];
		fa[ch[x][xs^1]]=f; ch[x][xs^1]=f;
		updata(f); updata(x);
	}
	void splay(int x,int y)
	{
		int sd1,sd2;
		if (x==y) return;
		while (fa[x]!=y)
		{
			if (fa[fa[x]]==y) rotate(x);
			else {
				sd1=sd(x); sd2=sd(fa[x]);
				if (sd1^sd2) rotate(x); else rotate(fa[x]);
				rotate(x);
			}
		}
		if (!fa[x]) root=x;
	}
	void insert(int W)
	{
		int x=root,lst=0,lsd=0;
		while (1) {
			if (w[x]==W) { c[x]++; sz[x]++; splay(x,0); return; }
			if (!x) { x=newnode(W); fa[x]=lst; ch[lst][lsd]=x; splay(x,0); return; }
			if (w[x]<W) { lst=x; lsd=1; x=rg; }
			else { lst=x; lsd=0; x=lf; }
		}
	}
	int find_kth(int rt,int k) //ÎÞ·¨ÅжÏÕÒ²»µ½µÄÇé¿öŶ 
	{
		int x=rt;
		while (1)
		{
			if (sz[lf]<k && sz[lf]+c[x]>=k) {
				splay(x,fa[rt]); return x;
			}
			else if (sz[lf]>=k) x=lf;
			else {k-=sz[lf]+c[x]; x=rg;}
		}
	}
	void del(int x)
	{
		int suc;
		c[x]--; splay(x,0);
		if (c[x]>0) return;
		suc=find_kth(rg,1);
		root=suc; ch[suc][0]=lf; fa[lf]=suc; fa[suc]=0; stk[++tp]=x;
		updata(suc);
	}
	int find_l(int W)
	{
		int x=root,ll=0;
		while (1) {
			if (!x) break;
			if (w[x]==W) {ll=x; break;}
			if (w[x]<W) ll=x,x=rg; else x=lf;
		}
		splay(ll,0); return ll;
	}
	int find_r(int W)
	{
		int x=root,rr=0;
		while (1) {
			if (!x) break;
			if (w[x]==W) { rr=x; break;}
			if (w[x]<W) x=rg; else rr=x,x=lf;
		}
		splay(rr,0); return rr;
	}
	int find_rank(int W) {
		int x=find_l(W);
		if (w[x]==W) return sz[lf];
		return c[x]+sz[lf];
	}
	int find_pre(int W) {
		int x=find_l(W),y;
		if (w[x]==W) y=find_kth(lf,sz[lf]);
		else y=x;
		return w[y];
	}
	int find_suc(int W) {
		int x=find_r(W),y;
		if (w[x]==W) y=find_kth(rg,1);
		else y=x;
		return w[y];
	}
	void dfs(int x)
	{
		printf("id=%d w=%d sz=%d lc=%d rc=%d fa=%d\n",x,w[x],sz[x],ch[x][0],ch[x][1],fa[x]);
		if (ch[x][0]) dfs(ch[x][0]);
		if (ch[x][1]) dfs(ch[x][1]);
	}
	void outit() {
		cs++; printf("Case #%d:\n",cs);
		dfs(root);
		printf("\n");
	}
}tr;
// Read one decimal integer from stdin into ret.
// Non-digit characters before the number are skipped, and any '-' among them
// makes the result negative. If end of input is reached before a digit, ret
// is 0. The original version looped forever there, because a `char` holding
// EOF never compares as a digit.
void read(int &ret)
{
	int ch=getchar(), sgn=1;  // int, not char, so EOF stays distinguishable
	ret=0;
	while (ch!=EOF && (ch<'0' || ch>'9')) {
		if (ch=='-') sgn=-1;
		ch=getchar();
	}
	while (ch>='0' && ch<='9') {
		ret=ret*10+(ch-'0');
		ch=getchar();
	}
	ret*=sgn;
}
// Process n operations of BZOJ 3224 ("普通平衡樹") on the global splay tree tr:
//   1 x: insert x           2 x: delete one copy of x
//   3 x: rank of x          4 x: x-th smallest value
//   5 x: predecessor of x   6 x: successor of x
// Operations 3-6 print their answer on its own line; other opcodes print nothing.
int main()
{
//	freopen("bst.in","r",stdin);
//	freopen("bst.out","w",stdout);
	tr.init();
	read(n);
	rep(i,1,n)
	{
		read(od); read(x);
		switch (od) {
		case 1:
			tr.insert(x);
			break;
		case 2: {
			// find_l falls back to the predecessor when x is absent. Delete only
			// on an exact match, so a missing key cannot remove a different
			// value (or the -inf sentinel).
			int p=tr.find_l(x);
			if (p && tr.w[p]==x) tr.del(p);
			break;
		}
		case 3:
			printf("%d\n",tr.find_rank(x));
			break;
		case 4:
			// +1 skips the -inf sentinel, which is always the smallest element
			printf("%d\n",tr.w[tr.find_kth(tr.root,x+1)]);
			break;
		case 5:
			printf("%d\n",tr.find_pre(x));
			break;
		case 6:
			printf("%d\n",tr.find_suc(x));
			break;
		default:
			break;
		}
//		tr.outit();
	}
	return 0;
}