程式人生 > 平衡樹模板三題

平衡樹模板三題

您需要寫一種資料結構（可參考題目標題），來維護一個序列，其中需要提供以下操作：翻轉一個區間。例如原有序序列是 5 4 3 2 1，翻轉區間是 [2,4] 的話，結果是 5 2 3 4 1。下面的程式中，序列初始為 $1, 2, \dots, n$，依序執行 $m$ 次翻轉後輸出最終序列。

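以支援區間翻轉（reverse，即交換左右子樹）的 Splay 實作。

做法如下：
- 在序列首尾各加一個哨兵節點。
- 把第 $l-1$ 個元素對應的節點伸展到根。
- 再把第 $r+1$ 個元素對應的節點伸展成根的右兒子。
- 此時區間 $[l, r]$ 恰好是後者的左子樹，對其打上翻轉懶標記即可。

懶標記要等到之後存取節點時才會下傳。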

C++程式碼

#include <cstdio>

// Exchanges the values of a and b through a copy-initialized temporary.
// No `register` specifier: that storage class was removed in C++17.
template <class T>
inline void swap(T &a, T &b) {
	T c = a;
	a = b;
	b = c;
}

// Capacity of every per-node array. Node ids are positive; id 0 stands for
// "no node".
const int mxn = 1000005;

int n, m;   // sequence length and number of operations (read in main)
int root;   // id of the current splay-tree root

// Parallel per-node arrays, indexed by node id:
// parent, left child, right child, pending reverse tag, subtree size.
int fat[mxn], lsn[mxn], rsn[mxn], rev[mxn], siz[mxn];

/*
 * Builds a perfectly balanced tree over the node ids l..r, so that in-order
 * order equals id order, and returns its root (the middle id).
 * fat/lsn/rsn/siz are filled in for every node of the range.
 * Assumes l <= r and that the arrays are still zero for these ids.
 */
int build(int l, int r) {
	const int mid = (l + r) >> 1;

	if (mid > l) {
		const int left = build(l, mid - 1);
		lsn[mid] = left;
		fat[left] = mid;
	}
	if (mid < r) {
		const int right = build(mid + 1, r);
		rsn[mid] = right;
		fat[right] = mid;
	}

	siz[mid] = 1 + siz[lsn[mid]] + siz[rsn[mid]];
	return mid;
}

// Recomputes the subtree size of node t from its children.
// Relies on siz[0] staying 0, so missing children count as empty.
inline void update(int t) {
	const int left = siz[lsn[t]];
	const int right = siz[rsn[t]];
	siz[t] = left + right + 1;
}

// Applies node t's pending reverse tag: clears it, swaps t's children, and
// toggles the tag on each existing child.
// Does nothing for t == 0 or when no tag is pending.
inline void pushdown(int t) {
	if (t == 0 || rev[t] == 0)
		return;

	rev[t] = 0;

	// The old right child becomes the left one and vice versa.
	const int newLeft = rsn[t];
	const int newRight = lsn[t];
	lsn[t] = newLeft;
	rsn[t] = newRight;

	if (newLeft)
		rev[newLeft] ^= 1;
	if (newRight)
		rev[newRight] ^= 1;
}

inline void pushdown(int t, int p) {
	static int stk[mxn], top;
	
	for (stk[++top] = t; t != p; )
		stk[++top] = t = fat[t];
	
	while (top)pushdown(stk[top--]);
}

/*
 * Single rotation that lifts t above its parent f and preserves in-order
 * order. The steps are:
 *   - the child of t on the side facing f is handed over to f;
 *   - the grandparent's child link is re-pointed at t, or the global `root`
 *     is updated when f was the root;
 *   - the sizes of f and then t are recomputed.
 * Assumes t has a parent. Callers (splay) push lazy tags down first.
 */
inline void rotate(int t) {
	const int f = fat[t];
	const int g = fat[f];
	const bool fromLeft = (t == lsn[f]);

	// `inner` is the child of t that moves under f; `slot` is where t hung.
	int &inner = fromLeft ? rsn[t] : lsn[t];
	int &slot = fromLeft ? lsn[f] : rsn[f];

	slot = inner;
	if (inner)
		fat[inner] = f;
	inner = f;

	fat[f] = t;
	fat[t] = g;

	if (!g)
		root = t;
	else if (lsn[g] == f)
		lsn[g] = t;
	else
		rsn[g] = t;

	update(f);
	update(t);
}

// Rotates t upward until its parent is p; p == 0 makes t the root.
// Lazy tags on the path are pushed down first, so the rotations see the
// true child order.
inline void splay(int t, int p) {
	pushdown(t, p);

	for (int f = fat[t]; f != p; f = fat[t]) {
		const int g = fat[f];

		if (g != p) {
			// Zig-zig rotates the parent first; zig-zag rotates t twice.
			const bool sameSide = (lsn[f] == t) == (lsn[g] == f);
			rotate(sameSide ? f : t);
		}

		rotate(t);
	}
}

// Returns the node at 1-based in-order position k, pushing tags down along
// the search path. Returns 0 when k is outside 1..siz[root].
inline int kth(int k) {
	int t = root;

	while (t) {
		pushdown(t);

		const int leftSize = siz[lsn[t]];
		if (k <= leftSize) {
			t = lsn[t];
			continue;
		}

		k -= leftSize + 1;
		if (k == 0)
			return t;

		t = rsn[t];
	}

	return 0;
}

void travel(int t) {
	pushdown(t);
	
	if (lsn[t])travel(lsn[t]);
	
	if (t > 1 && t < n + 2)
		printf("%d ", t - 1);
	
	if (rsn[t])travel(rsn[t]);
}

/*
 * Reads n and m, then m reversal queries "l r" (1-based, inclusive), and
 * prints the final sequence.
 * Node id i represents the value i - 1. Ids 1 and n + 2 are sentinels, so
 * every query range has a neighbour on both sides.
 */
signed main() {
	scanf("%d%d", &n, &m);

	root = build(1, n + 2);

	for (int q = 0; q < m; ++q) {
		int l, r;
		scanf("%d%d", &l, &r);

		// Because of the leading sentinel, sequence position x has rank x + 1,
		// so the neighbours of [l, r] have ranks l and r + 2.
		const int before = kth(l);
		const int after = kth(r + 2);

		splay(before, 0);
		splay(after, before);

		// The queried range is now exactly the left subtree of `after`.
		rev[lsn[after]] ^= 1;
	}

	travel(root);
}

二逼平衡樹

您需要寫一種資料結構(可參考題目標題),來維護一個有序數列,其中需要提供以下操作:

  1. 查詢 $x$ 在區間內的排名（即區間內小於 $x$ 的數的個數加一）；
  2. 查詢區間內排名為 $k$ 的值；
  3. 修改某一位置上的數值；
  4. 查詢 $x$ 在區間內的前驅（前驅定義為小於 $x$，且最大的數）；
  5. 查詢 $x$ 在區間內的後繼（後繼定義為大於 $x$，且最小的數）。
這題一般來說需要樹套樹，比如線段樹套 Splay 平衡樹。

下面的程式正是如此：
- 每個線段樹節點用一棵 Splay 存放其區間內的所有數值。
- 查詢排名為 $k$ 的值時，則在值域上二分，並反覆呼叫排名查詢。

C++程式碼

#include <cstdio>

#define min(a, b) (a < b ? a : b)
#define max(a, b) (a > b ? a : b)

template <class T>
inline T getmin(const T &a, const T &b) {
	return min(a, b);
}

template <class T>
inline T getmax(const T &a, const T &b) {
	return max(a, b);
}

// Capacity limits:
//   mxn - bounds the input array s (indexed 1..n);
//   mxm - sizes the node arrays of the splay and segment trees.
// inf is the "no predecessor/successor" sentinel (-inf / +inf). It also
// bounds the value range searched by the k-th query.
const int mxn = 100000 + 5;
const int mxm = 5000000 + 5;
const int inf = 100000000 + 8;

namespace splay {
	// Node pool shared by all the splay trees in the program. Node ids start
	// at 1; id 0 means "no node", and siz[0] is relied on to stay 0.
	// Each tree is identified only by its root id, which callers keep
	// (e.g. in stree::rot) and pass by reference to ins/rmv.
	int tot;          // ids handed out so far; removed nodes are never reused
	int siz[mxm];     // subtree size
	int num[mxm];     // value stored in the node
	int fat[mxm];     // parent id (0 for the root of a tree)
	int son[mxm][2];  // son[t][0] = left child, son[t][1] = right child
	
	// Recomputes siz[t] from its two children. This is a real function rather
	// than a function-like macro: it is type-checked, evaluates t once and
	// cannot clash with other names called `update`.
	inline void update(int t) {
		siz[t] = siz[son[t][0]] + siz[son[t][1]] + 1;
	}
	
	/*
	 * Adds v as a new leaf of the tree rooted at t, following BST order.
	 * v <= num goes left, so equal values go to the left.
	 * siz is incremented on every node passed on the way down.
	 * f is recorded as the new node's parent only when t itself is empty.
	 * The empty slot that receives the node (t itself or a child link) is
	 * set to the new id, which is returned. Does not splay; see ins.
	 *
	 * Iterative on purpose: an unbalanced splay tree can be very deep, and
	 * a recursive descent would use one stack frame per level.
	 */
	int insert(int &t, int f, int v) {
		int *slot = &t;
		
		while (*slot) {
			f = *slot;
			++siz[f];
			slot = &son[f][v <= num[f] ? 0 : 1];
		}
		
		const int x = ++tot;
		*slot = x;
		fat[x] = f;
		num[x] = v;
		siz[x] = 1;
		son[x][0] = 0;
		son[x][1] = 0;
		
		return x;
	}
	
	/*
	 * Single rotation lifting t above its parent f, preserving in-order
	 * order. If f had a parent g, g's child link is re-pointed at t.
	 * The caller's stored root id is NOT updated here: ins/rmv reassign it
	 * after splaying.
	 */
	inline void rotate(int t) {
		int f = fat[t];
		int g = fat[f];
		int a = son[f][1] == t;     // side of f that t hangs on
		int b = !a, s = son[t][b];  // s: t's child that moves under f
		
		fat[t] = g;
		fat[f] = t;
		
		son[t][b] = f;
		son[f][a] = s;
		
		if (s)
			fat[s] = f;
		if (g) {
			if (son[g][0] == f)
				son[g][0] = t;
			else
				son[g][1] = t;
		}
		
		update(f);
		update(t);
	}
	
	// Rotates t upward until its parent is p (default 0, i.e. until t is the
	// root of its tree), using zig-zig / zig-zag steps.
	inline void splay(int t, int p = 0) {
		while (fat[t] != p) {
			int f = fat[t];
			int g = fat[f];
			
			if (g == p)
				rotate(t);
			else {
				int a = son[f][1] == t;
				int b = son[g][1] == f;
				
				if (a == b)
					rotate(f), rotate(t);
				else
					rotate(t), rotate(t);
			}
		}
	}
	
	// NOTE(review): the query functions below (rnk, kth, fnd, pre, nxt) walk
	// the tree without splaying. Their cost therefore tracks the current
	// tree depth, which may be large, e.g. after inserting already-sorted
	// values. Consider splaying the last visited node if this is a
	// bottleneck.
	
	// Number of values strictly less than v in the tree rooted at t.
	inline int rnk(int t, int v) {
		int k = 0;
		
		while (t) {
			if (v <= num[t])	
				t = son[t][0];
			else {
				// Count t and its left subtree: siz[t] minus the right subtree.
				k += siz[t];
				
				t = son[t][1];
				
				k -= siz[t];
			}
		}
		
		return k;
	}
	
	// Node holding the k-th smallest value (1-based) in the tree rooted at
	// t, or 0 if k is out of range.
	inline int kth(int t, int k) {
		while (t) {
			int s = siz[son[t][0]];
			
			if (k <= s)
				t = son[t][0];
			else if (k == ++s)
				return t;
			else
				t = son[t][1], k -= s;
		}
		
		return 0;
	}
	
	// Some node whose value equals v in the tree rooted at t, or 0 if none.
	inline int fnd(int t, int v) {
		while (t) {
			if (num[t] == v)
				return t;
			
			if (v < num[t])
				t = son[t][0];
			else
				t = son[t][1];
		}
		
		return 0;
	}
	
	// Largest value strictly less than v in the tree rooted at t, or -inf.
	inline int pre(int t, int v) {
		int r = -inf;
		
		while (t) {
			if (v <= num[t])
				t = son[t][0];
			else {
				r = max(r, num[t]);
				t = son[t][1];
			}
		}
		
		return r;
	}
	
	// Smallest value strictly greater than v in the tree rooted at t, or +inf.
	inline int nxt(int t, int v) {
		int r = +inf;
		
		while (t) {
			if (v >= num[t])
				t = son[t][1];
			else {
				r = min(r, num[t]);
				t = son[t][0];
			}
		}
		
		return r;
	}
	
	// Inserts v into the tree whose root id is t, splays the new node to the
	// root, and stores it back into t.
	inline void ins(int &t, int v) {
		int p = insert(t, 0, v);
		splay(p, 0);
		t = p;
	}
	
	/*
	 * Removes one node holding v from the tree whose root id is t, if such a
	 * node exists, and stores the new root id (possibly 0) back into t.
	 * The found node is splayed to the root first. When it has two
	 * children, its in-order predecessor is splayed directly below it and
	 * takes over its right subtree. The removed id is not recycled.
	 */
	inline void rmv(int &t, int v) {
		int p = fnd(t, v);
		
		if (p) {
			splay(p, 0);
			
			if (!son[p][0])
				t = son[p][1], fat[t] = 0;
			else if (!son[p][1])
				t = son[p][0], fat[t] = 0;
			else {
				t = son[p][0];
				
				while (son[t][1])
					t = son[t][1];
				
				// t is the maximum of p's left subtree, so after this splay
				// it is p's left child with no right child of its own.
				splay(t, p);
				
				son[t][1] = son[p][1];
				
				fat[son[p][1]] = t;
				
				fat[t] = 0;
				
				update(t);
			}
		}
	}
	
	// Debug helper: dumps the tree rooted at t in order, one node per line.
	void print(int t) {
		if (!t)return;
		
		print(son[t][0]);
		
		printf("%d: [%d] (%d) (%d %d) {%d}\n", 
			t, num[t], fat[t], son[t][0], son[t][1], siz[t]);
		
		print(son[t][1]);
	}
}

namespace stree {
	// Segment tree over positions 1..n. Node t covers a range [l, r] and owns
	// a splay tree (root id rot[t]) holding the values at those positions.
	int tot;       // segment-tree node ids handed out so far
	int rot[mxm];  // splay root of each segment node (0 = empty)
	int lsn[mxm];  // left child
	int rsn[mxm];  // right child
	
	// Creates the node for [l, r] and inserts s[l..r] into its splay tree,
	// then builds the two halves. Ids are assigned in pre-order; the new id
	// is returned.
	int build(int l, int r, int *s) {
		const int node = ++tot;
		
		for (int pos = l; pos <= r; ++pos)
			splay::ins(rot[node], s[pos]);
		
		if (l != r) {
			const int mid = (l + r) >> 1;
			lsn[node] = build(l, mid, s);
			rsn[node] = build(mid + 1, r, s);
		}
		
		return node;
	}
	
	// Point update: position p changes from value a to value b. Every node on
	// the path from t down to the leaf of p drops one copy of a and gains b.
	void update(int t, int l, int r, int p, int a, int b) {
		for (;;) {
			splay::rmv(rot[t], a);
			splay::ins(rot[t], b);
			
			if (l == r)
				break;
			
			const int mid = (l + r) >> 1;
			if (p <= mid) {
				t = lsn[t];
				r = mid;
			} else {
				t = rsn[t];
				l = mid + 1;
			}
		}
	}
	
	// Number of values strictly less than v among positions [x, y].
	// Assumes l <= x <= y <= r.
	int queryRnk(int t, int l, int r, int x, int y, int v) {
		if (x == l && r == y)
			return splay::rnk(rot[t], v);
		
		const int mid = (l + r) >> 1;
		int total = 0;
		
		if (x <= mid)
			total += queryRnk(lsn[t], l, mid, x, y < mid ? y : mid, v);
		if (y > mid)
			total += queryRnk(rsn[t], mid + 1, r, x > mid ? x : mid + 1, y, v);
		
		return total;
	}
	
	// Largest value strictly less than v among positions [x, y], or -inf.
	// Assumes l <= x <= y <= r.
	int queryPre(int t, int l, int r, int x, int y, int v) {
		if (x == l && r == y)
			return splay::pre(rot[t], v);
		
		const int mid = (l + r) >> 1;
		
		if (y <= mid)
			return queryPre(lsn[t], l, mid, x, y, v);
		if (x > mid)
			return queryPre(rsn[t], mid + 1, r, x, y, v);
		
		const int fromLeft = queryPre(lsn[t], l, mid, x, mid, v);
		const int fromRight = queryPre(rsn[t], mid + 1, r, mid + 1, y, v);
		return fromLeft > fromRight ? fromLeft : fromRight;
	}
	
	// Smallest value strictly greater than v among positions [x, y], or +inf.
	// Assumes l <= x <= y <= r.
	int queryNxt(int t, int l, int r, int x, int y, int v) {
		if (x == l && r == y)
			return splay::nxt(rot[t], v);
		
		const int mid = (l + r) >> 1;
		
		if (y <= mid)
			return queryNxt(lsn[t], l, mid, x, y, v);
		if (x > mid)
			return queryNxt(rsn[t], mid + 1, r, x, y, v);
		
		const int fromLeft = queryNxt(lsn[t], l, mid, x, mid, v);
		const int fromRight = queryNxt(rsn[t], mid + 1, r, mid + 1, y, v);
		return fromLeft < fromRight ? fromLeft : fromRight;
	}
}

int n, m, s[mxn];  // sequence length, number of operations, values s[1..n]

/*
 * Reads n, m and s[1..n], builds the tree of trees, then processes m
 * operations of the form "k x y [z]":
 *   k = 1 : print 1 + (number of values < z in [x, y])
 *   k = 3 : set s[x] = y (no z is read)
 *   k = 4 : print the largest value < z in [x, y], or -inf if there is none
 *   k = 5 : print the smallest value > z in [x, y], or +inf if there is none
 *   other : print the z-th smallest value in [x, y] (operation 2)
 */
signed main() {
//	freopen("in", "r", stdin);
//	freopen("out", "w", stdout);
	
	scanf("%d%d", &n, &m);
	
	for (int i = 1; i <= n; ++i)
		scanf("%d", s + i);
	
	stree::build(1, n, s);
	
	for (int i = 1, k, x, y, z; i <= m; ++i) {
		scanf("%d%d%d", &k, &x, &y);
		
		if (k == 3)
			stree::update(1, 1, n, x, s[x], y), s[x] = y;
		else {
			scanf("%d", &z);
			
			if (k == 1)
				printf("%d\n", stree::queryRnk(1, 1, n, x, y, z) + 1);
			else if (k == 4)
				printf("%d\n", stree::queryPre(1, 1, n, x, y, z));
			else if (k == 5)
				printf("%d\n", stree::queryNxt(1, 1, n, x, y, z));
			else {
				// Binary search over the value range for the largest mid whose
				// count of smaller values in [x, y] is still below z. That mid
				// is the z-th smallest value.
				// ans starts at -inf so it is always defined: when no mid
				// qualifies (e.g. z <= 0), it was previously read uninitialized.
				int lt = -inf, rt = +inf, mid, ans = -inf;
				
				while (lt <= rt) {
					mid = (lt + rt) >> 1;
					
					if (stree::queryRnk(1, 1, n, x, y, mid) < z)
						lt = mid + 1, ans = mid;
					else
						rt = mid - 1;
				}
				
				printf("%d\n", ans);
			}
		}
	}
}