1. 程式人生 > >樹套樹——樹狀陣列套主席樹

樹套樹——樹狀陣列套主席樹

線段樹套平衡樹是什麼腦殘東西,複雜度就是假的,O(nlog3n)讓人感覺非常不靠譜。所以我們為什麼不用更好寫的樹狀陣列代替線段樹,更好寫的主席樹(權值線段樹)代替平衡樹呢?而且,不僅是好寫,複雜度也是很對的O(nlog2n)啊。

我們來簡單理解一下樹套樹是什麼:
思想其實很簡單,樹狀陣列的每個節點都是一顆權值線段樹,並且每顆權值線段樹維護的資訊都是樹狀陣列式的累加,這樣每次查詢和修改都只需要對logn個線段樹進行操作。

對比一下普通的樹狀陣列和樹套樹:
這裡寫圖片描述
額,,,差不多就是這樣了,確實很簡單吧。

下面這個程式碼是二逼平衡樹的板子,除了修改和查詢k大還有求rank和前驅後繼(所以我在講平衡樹的時候說過維護權值的事線段樹也能做嘛)

Code

#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>

const int maxn = 5e4 + 7;
const int inf = 0x7fffffff;

using namespace std;

int n, m;
int a[maxn];
int val[maxn << 1], tot;
int opt[maxn];
int qa[maxn];
int qb[maxn];
int qc[maxn];

struct
node { int sum; int l, r; } st[maxn * 400]; int cnt; int root[maxn]; int xx[20], cnt1; int yy[20], cnt2; inline int read() { int X = 0; char ch = getchar(); while (ch < '0' || ch > '9') ch = getchar(); while (ch >= '0' && ch <= '9') X = X * 10 + ch - '0', ch = getchar(); return
X; } inline int lowbit(int x) { return x & -x; } void update(int num, int &rt, int l, int r, int x) { st[++cnt] = st[rt]; rt = cnt; st[rt].sum += x; if (l == r) return; int mid = l + r >> 1; if (num <= mid) update(num, st[rt].l, l, mid, x); else update(num, st[rt].r, mid + 1, r, x); } int get_rnk(int l, int r, int x) { if (l == r) return 0; int mid = l + r >> 1; if (x <= mid) { for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l; for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l; return get_rnk(l, mid, x); } int d = 0; for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum; for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum; for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r; for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r; return get_rnk(mid + 1, r, x) + d; } int get_kth(int l, int r, int k) { if (l == r) return val[l]; int d = 0, mid = l + r >> 1; for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum; for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum; if (k <= d) { for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l; for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l; return get_kth(l, mid, k); } for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r; for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r; return get_kth(mid + 1, r, k - d); } inline void add(int i, int x) { int k = lower_bound(val + 1, val + tot + 1, a[i]) - val; for (; i <= n; i += lowbit(i)) update(k, root[i], 1, tot, x); } inline void init_query(int i) { cnt1 = cnt2 = 0; for (int j = qa[i] - 1; j; j -= lowbit(j)) xx[++cnt1] = root[j]; for (int j = qb[i]; j; j -= lowbit(j)) yy[++cnt2] = root[j]; } int main(void) { cin >> n >> m; tot = n; for (int i = 1; i <= n; i++) a[i] = val[i] = read(); for (int i = 1; i <= m; i++) { opt[i] = read(); qa[i] = read(); qb[i] = read(); if (opt[i] != 3) { qc[i] = read(); if (opt[i] != 2) val[++tot] = qc[i]; } else val[++tot] = qb[i]; } sort(val + 1, val + tot + 1); tot = unique(val + 1, val + tot + 1) - val - 1; st[0] = {0, 0, 0}; for (int i = 1; i <= n; i++) add(i, 1); for (int i = 1; i <= m; i++) { if (opt[i] != 3) init_query(i); if (opt[i] != 2 && opt[i] != 3) qc[i] = lower_bound(val + 1, val + tot + 1, qc[i]) - val; if (opt[i] == 1) printf("%d\n", get_rnk(1, tot, qc[i]) + 1); else if (opt[i] == 2) printf("%d\n", get_kth(1, tot, qc[i])); else if (opt[i] == 3) { add(qa[i], -1); a[qa[i]] = qb[i]; add(qa[i], 1); } else if (opt[i] == 4) { int k = get_rnk(1, tot, qc[i]); if (!k) printf("%d\n", -inf); else { init_query(i); printf("%d\n", get_kth(1, tot, k)); } } else { int k = get_rnk(1, tot, qc[i] + 1); if (k > qb[i] - qa[i]) printf("%d\n", inf); else { init_query(i); printf("%d\n", get_kth(1, tot, k + 1)); } } } return 0; }