1. 程式人生 > 實用技巧 > 關於那些平衡樹的板子

關於那些平衡樹的板子

Treap

Luogu P6136

#include <algorithm>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <iostream>
using namespace std;

// Read the next integer from stdin.
// Every non-digit character before the first digit is skipped; if any of the
// skipped characters is '-', the result is negated. Returns 0 if EOF is hit
// before a digit. (The previous version kept the character in a `char` and
// never tested for EOF, so a truncated input made the skip loop spin forever.)
inline int read(){
	int x = 0, w = 1;
	int ch = getchar(); // int, not char: EOF must stay distinguishable from data bytes
	for(; ch != EOF && (ch > '9' || ch < '0'); ch = getchar()) if(ch == '-') w = -1;
	// (ch - '0') is grouped first so x * 10 + ch cannot overflow before the subtraction.
	for(; ch >= '0' && ch <= '9'; ch = getchar()) x = x * 10 + (ch - '0');
	return x * w;
}

// Capacity of the node pool a[] and of the input buffer s[].
// Nodes are never recycled (New always takes the next slot), so the pool must
// hold the two sentinels, the n initial values and one node per insert
// operation: up to n + m + 2 slots.
// NOTE(review): sized for n <= 1e5 initial values plus m <= 1e6 operations,
// the limits I recall for Luogu P6136 -- confirm against the statement.
const int ss = 1100010;
// Sentinel keys. -inf / +inf must lie strictly outside every key that read()
// can return, so use INT_MAX. The former 0x7ffffff (2^27 - 1, one 'f' short)
// could be exceeded by ordinary input keys, breaking the sentinel fallbacks.
const int inf = 0x7fffffff;

// One treap node. a[] is a global, so every slot starts zeroed; New() hands
// out slots from index 1 upward, which leaves a[0] as the shared empty child.
struct treap{
	int l;    // index of the left child in a[] (0 = none)
	int r;    // index of the right child in a[] (0 = none)
	int val;  // BST key
	int data; // priority; insert/remove rotate higher values toward the root
	int cnt;  // number of copies of val held by this node
	int size; // total copies in this subtree (recomputed by update)
};
treap a[ss]; // node pool

int tot;  // number of slots handed out so far (index of the newest node)
int root; // index of the current root node
int n, m; // initial value count and operation count, read in main

// Buffer for the n initial values read in main.
int s[ss];

// Allocate the next node slot with key `val` (one copy, subtree size 1) and
// return its index. Slots are handed out sequentially and never reused; there
// is no capacity check, so at most ss - 1 nodes may be created.
// The heap priority comes from rand(). It used to be copied from s[tot], the
// input buffer: that entry is still 0 whenever a node is created (and is read
// past the end of s once tot >= ss), so every node got the same priority,
// rotations never fired, and the treap degraded to an unbalanced BST.
// main does not call srand, so the priority sequence is the same on each run.
inline int New(int val){
	a[++tot].val = val;
	a[tot].data = rand();
	a[tot].cnt = a[tot].size = 1;
	return tot;
}

// Recompute the subtree size of node p from its own copy count and its
// children's (already up-to-date) sizes. Child index 0 contributes a[0].size.
inline void update(int p){
	const int leftSize = a[a[p].l].size;
	const int rightSize = a[a[p].r].size;
	a[p].size = leftSize + rightSize + a[p].cnt;
}

// Initialise the tree with two sentinels: node 1 holds -inf and is the root,
// node 2 holds +inf as its right child. getpre/getnxt rely on these indices,
// and main offsets rank/val queries by one to skip the -inf node.
inline void build(){
	New(-inf); // node 1
	New(inf);  // node 2
	root = 1;
	a[root].r = 2;
	update(root);
}

// Rank of `val` within the subtree rooted at p: the number of stored copies
// strictly smaller than `val`, plus one. Works whether or not `val` is stored.
// (Previously an empty link returned 0, so an absent key came out one too low.)
// Recursive calls are written `::rank` because, under `using namespace std;`,
// an unqualified `rank` may also find the C++11 `std::rank` trait and be
// rejected as ambiguous; the qualified form only sees this function.
inline int rank(int p, int val){
	// Empty link: everything smaller has already been counted; add the "+1".
	if(p == 0) return 1;
	if(val == a[p].val) return a[a[p].l].size + 1;
	if(val < a[p].val) return ::rank(a[p].l, val);
	// Going right: the whole left subtree and this node's copies are smaller.
	return ::rank(a[p].r, val) + a[a[p].l].size + a[p].cnt;
}

// Return the key of the `rank`-th smallest copy (1-based) in the subtree
// rooted at p, or inf when the subtree holds fewer than `rank` copies.
// Iterative walk: each step right discards the left subtree and this node.
inline int val(int p, int rank){
	while(p != 0){
		const int leftSize = a[a[p].l].size;
		if(rank <= leftSize){
			p = a[p].l;
			continue;
		}
		if(rank <= leftSize + a[p].cnt) return a[p].val;
		rank -= leftSize + a[p].cnt;
		p = a[p].r;
	}
	return inf;
}

// Right rotation: lift p's left child into p's place; p becomes that child's
// right child and adopts its former right subtree. p is a reference to the
// parent's link (or to root), so the caller's tree is re-linked in place.
// Sizes are refreshed bottom-up: the demoted node first, then the new top.
inline void zig(int &p){
	int lifted = a[p].l;
	a[p].l = a[lifted].r;
	a[lifted].r = p;
	p = lifted;
	update(a[lifted].r);
	update(lifted);
}

// Left rotation: lift p's right child into p's place; p becomes that child's
// left child and adopts its former left subtree. p is a reference to the
// parent's link (or to root), so the caller's tree is re-linked in place.
// Sizes are refreshed bottom-up: the demoted node first, then the new top.
inline void zag(int &p){
	int lifted = a[p].r;
	a[p].r = a[lifted].l;
	a[lifted].l = p;
	p = lifted;
	update(a[lifted].l);
	update(lifted);
}

// Insert one copy of `val` into the subtree linked by p (a reference to the
// parent's child slot or to root, so a new node or rotation re-links it).
// After descending, rotate the child up if its priority exceeds p's.
inline void insert(int &p, int val){
	// Empty slot: hang a fresh leaf here.
	if(p == 0){
		p = New(val);
		return;
	}
	// Key already stored: just count another copy.
	if(a[p].val == val){
		++a[p].cnt;
		update(p);
		return;
	}
	const bool goLeft = val < a[p].val;
	if(goLeft){
		insert(a[p].l, val);
		if(a[a[p].l].data > a[p].data) zig(p); // right rotation
	}else{
		insert(a[p].r, val);
		if(a[a[p].r].data > a[p].data) zag(p); // left rotation
	}
	update(p);
}

// Predecessor: the largest stored key strictly less than `val`, or -inf
// (the sentinel in node 1) when there is none. Iterative search from root.
inline int getpre(int val){
	int best = 1; // node 1 holds the -inf sentinel
	for(int p = root; p; p = val < a[p].val ? a[p].l : a[p].r){
		if(a[p].val == val){
			// Exact hit: the answer is the rightmost node of the left subtree,
			// if there is one; otherwise keep the best ancestor seen so far.
			if(a[p].l > 0){
				int q = a[p].l;
				while(a[q].r > 0) q = a[q].r;
				best = q;
			}
			break;
		}
		// A key below val on the search path is a candidate; keep the largest.
		if(a[p].val < val && a[p].val > a[best].val) best = p;
	}
	return a[best].val;
}

// Successor: the smallest stored key strictly greater than `val`, or +inf
// (the sentinel in node 2) when there is none. Iterative search from root.
inline int getnxt(int val){
	int best = 2; // node 2 holds the +inf sentinel
	for(int p = root; p; p = val < a[p].val ? a[p].l : a[p].r){
		if(a[p].val == val){
			// Exact hit: the answer is the leftmost node of the right subtree,
			// if there is one; otherwise keep the best ancestor seen so far.
			if(a[p].r > 0){
				int q = a[p].r;
				while(a[q].l > 0) q = a[q].l;
				best = q;
			}
			break;
		}
		// A key above val on the search path is a candidate; keep the smallest.
		if(a[p].val > val && a[p].val < a[best].val) best = p;
	}
	return a[best].val;
}

// Remove one copy of `val` from the subtree linked by p (a reference to the
// parent's child slot or to root). Does nothing if `val` is not stored.
// A node whose last copy is removed is rotated down until it is a leaf and
// then unlinked; its slot in a[] is not reused.
// The found-key branch previously had no `return`. Control then fell through
// into a second search, and after `p = 0` it went on to touch the null node a[0].
inline void remove(int &p, int val){
	if(p == 0) return; // not present

	// Not this node: recurse toward val, then refresh this node's size.
	if(val != a[p].val){
		if(val < a[p].val) remove(a[p].l, val);
		else remove(a[p].r, val);
		update(p);
		return;
	}

	// Several copies: drop one and keep the node.
	if(a[p].cnt > 1){
		a[p].cnt--;
		update(p);
		return;
	}

	// Last copy in a leaf: unlink it.
	if(a[p].l == 0 && a[p].r == 0){
		p = 0;
		return;
	}

	// Last copy in an inner node: lift the higher-priority child (never an
	// empty one) and continue deleting in the subtree the node moved into.
	if(a[p].r == 0 || (a[p].l != 0 && a[a[p].l].data > a[a[p].r].data)){
		zig(p);
		remove(a[p].r, val);
	}else{
		zag(p);
		remove(a[p].l, val);
	}
	update(p);
}

signed main(){
	build();
	n = read(), m = read();
	for(int i = 1; i <= n; i++){
		s[i] = read();
		insert(root, s[i]);
	}
	int ans = 0;
	bool flag = 0;
	while(m--){
		int op = read(), x = read();
		if(op == 1)	insert(root, x);
		else if(op == 2) remove(root, x);
		else if(op == 3) {
			int tmp = rank(root, x) - 1;
			printf("%d\n", tmp);
			if(!flag) ans = tmp, flag = 1;
			else ans ^= tmp;
		}
		else if(op == 4) {
			int tmp = val(root, x + 1);
			printf("%d\n", tmp);
			if(!flag) ans = tmp, flag = 1;
			else ans ^= tmp;
		}
		else if(op == 5){
			int tmp = getpre(x);
			printf("%d\n", tmp);
			if(!flag) ans = tmp, flag = 1;
			else ans ^= tmp;
		}
		else{
			int tmp = getnxt(x);
			printf("%d\n", tmp);
			if(!flag) ans = tmp, flag = 1;
			else ans ^= tmp;
		}
	}
	cout << "ans = " << ans << endl;
	return 0;
}