樹套樹——樹狀陣列套主席樹
阿新 • • 發佈:2019-02-14
線段樹套平衡樹是什麼腦殘東西,複雜度就是假的,讓人感覺非常不靠譜。所以我們為什麼不用更好寫的樹狀陣列代替線段樹,更好寫的主席樹(權值線段樹)代替平衡樹呢?而且,不僅是好寫,複雜度也是很對的啊。
我們來簡單理解一下樹套樹是什麼:
思想其實很簡單,樹狀陣列的每個節點都是一顆權值線段樹,並且每顆權值線段樹維護的資訊都是樹狀陣列式的累加,這樣每次查詢和修改都只需要對logn個線段樹進行操作。
對比一下普通的樹狀陣列和樹套樹:
額,,,差不多就是這樣了,確實很簡單吧。
下面這個程式碼是二逼平衡樹的板子,除了修改和查詢k大還有求rank和前驅後繼(所以我在講平衡樹的時候說過維護權值的事線段樹也能做嘛)
Code
#include <iostream>
#include <algorithm>
#include <cstring>
#include <cstdio>
const int maxn = 5e4 + 7;
const int inf = 0x7fffffff;
using namespace std;
int n, m;
int a[maxn];
int val[maxn << 1], tot;
int opt[maxn];
int qa[maxn];
int qb[maxn];
int qc[maxn];
struct node {
int sum;
int l, r;
} st[maxn * 400];
int cnt;
int root[maxn];
int xx[20], cnt1;
int yy[20], cnt2;
inline int read()
{
int X = 0; char ch = getchar();
while (ch < '0' || ch > '9') ch = getchar();
while (ch >= '0' && ch <= '9') X = X * 10 + ch - '0', ch = getchar();
return X;
}
inline int lowbit(int x)
{
return x & -x;
}
void update(int num, int &rt, int l, int r, int x)
{
st[++cnt] = st[rt];
rt = cnt;
st[rt].sum += x;
if (l == r) return;
int mid = l + r >> 1;
if (num <= mid) update(num, st[rt].l, l, mid, x);
else update(num, st[rt].r, mid + 1, r, x);
}
int get_rnk(int l, int r, int x)
{
if (l == r) return 0;
int mid = l + r >> 1;
if (x <= mid) {
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l;
return get_rnk(l, mid, x);
}
int d = 0;
for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum;
for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum;
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r;
return get_rnk(mid + 1, r, x) + d;
}
int get_kth(int l, int r, int k)
{
if (l == r) return val[l];
int d = 0, mid = l + r >> 1;
for (int i = 1; i <= cnt1; i++) d -= st[st[xx[i]].l].sum;
for (int i = 1; i <= cnt2; i++) d += st[st[yy[i]].l].sum;
if (k <= d) {
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].l;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].l;
return get_kth(l, mid, k);
}
for (int i = 1; i <= cnt1; i++) xx[i] = st[xx[i]].r;
for (int i = 1; i <= cnt2; i++) yy[i] = st[yy[i]].r;
return get_kth(mid + 1, r, k - d);
}
inline void add(int i, int x)
{
int k = lower_bound(val + 1, val + tot + 1, a[i]) - val;
for (; i <= n; i += lowbit(i)) update(k, root[i], 1, tot, x);
}
inline void init_query(int i)
{
cnt1 = cnt2 = 0;
for (int j = qa[i] - 1; j; j -= lowbit(j)) xx[++cnt1] = root[j];
for (int j = qb[i]; j; j -= lowbit(j)) yy[++cnt2] = root[j];
}
int main(void)
{
cin >> n >> m;
tot = n;
for (int i = 1; i <= n; i++) a[i] = val[i] = read();
for (int i = 1; i <= m; i++) {
opt[i] = read();
qa[i] = read();
qb[i] = read();
if (opt[i] != 3) {
qc[i] = read();
if (opt[i] != 2) val[++tot] = qc[i];
}
else val[++tot] = qb[i];
}
sort(val + 1, val + tot + 1);
tot = unique(val + 1, val + tot + 1) - val - 1;
st[0] = {0, 0, 0};
for (int i = 1; i <= n; i++) add(i, 1);
for (int i = 1; i <= m; i++) {
if (opt[i] != 3) init_query(i);
if (opt[i] != 2 && opt[i] != 3) qc[i] = lower_bound(val + 1, val + tot + 1, qc[i]) - val;
if (opt[i] == 1) printf("%d\n", get_rnk(1, tot, qc[i]) + 1);
else if (opt[i] == 2) printf("%d\n", get_kth(1, tot, qc[i]));
else if (opt[i] == 3) {
add(qa[i], -1);
a[qa[i]] = qb[i];
add(qa[i], 1);
}
else if (opt[i] == 4) {
int k = get_rnk(1, tot, qc[i]);
if (!k) printf("%d\n", -inf);
else {
init_query(i);
printf("%d\n", get_kth(1, tot, k));
}
}
else {
int k = get_rnk(1, tot, qc[i] + 1);
if (k > qb[i] - qa[i]) printf("%d\n", inf);
else {
init_query(i);
printf("%d\n", get_kth(1, tot, k + 1));
}
}
}
return 0;
}