樹套樹
阿新 • • 發佈:2020-11-16
樹套樹
樹套樹是一種思想:一棵樹的每個節點本身又是另一棵樹。
在外面的叫外層樹,在裡面的叫內層樹。
外層樹一般是樹狀陣列或線段樹。
內層樹一般是平衡樹、STL 容器(例如下面用到的 multiset)或線段樹。
線段樹套STL
/*
 * @Author: zhl
 * @Date: 2020-11-16 12:50:32
 *
 * Segment tree over positions 1..n in which every node owns a std::multiset
 * holding all values of the range it covers.
 *
 * Input format read by main():
 *   n m
 *   A[1] ... A[n]
 *   m operations, each one of
 *     1 pos val    assign A[pos] = val
 *     op l r x     (any op other than 1) print the largest value of A[l..r]
 *                  that is strictly smaller than x, or -inf when none exists
 */
#include <algorithm>
#include <cstdio>
#include <iterator>
#include <set>

// N bounds the array length; inf is the sentinel magnitude stored in every set.
// NOTE(review): input values are presumably strictly inside (-inf, inf) - confirm.
const int N = 50000 + 10, inf = 1000000000;

std::multiset<int> s[N << 2];  // s[o]: values of the range covered by tree node o
int A[N];                      // current array, 1-based
int n, m;                      // array length and number of operations

/*
 * Fills node o (covering [l, r]) and all of its descendants from A.
 * Each set also receives the sentinels +inf and -inf.
 * An empty range (l > r) is ignored so that n == 0 cannot recurse forever.
 */
void build(int o, int l, int r) {
    if (l > r) return;
    s[o].insert(inf);
    s[o].insert(-inf);
    s[o].insert(A + l, A + r + 1);
    if (l == r) return;
    const int mid = (l + r) >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
}

/*
 * Replaces the old value A[pos] by v in every node on the root-to-leaf path
 * of position pos. The caller must store v into A[pos] afterwards; this
 * function still needs the old value while it descends.
 */
void updt(int o, int l, int r, int pos, int v) {
    // Erase exactly one occurrence of the old value; if it is unexpectedly
    // absent, leave the set alone instead of erasing an unrelated element.
    const auto old_it = s[o].find(A[pos]);
    if (old_it != s[o].end()) s[o].erase(old_it);
    s[o].insert(v);
    if (l == r) return;
    const int mid = (l + r) >> 1;
    if (pos <= mid) updt(o << 1, l, mid, pos, v);
    else updt(o << 1 | 1, mid + 1, r, pos, v);
}

/*
 * Returns the largest value strictly smaller than v among A[L..R], looking
 * only at node o which covers [l, r]. Returns -inf when there is none
 * (including when [L, R] does not intersect [l, r]).
 */
int query(int o, int l, int r, int L, int R, int v) {
    // Disjoint ranges contribute nothing; this also stops malformed queries
    // (for example R < 1) from recursing forever at a leaf.
    if (R < l || r < L) return -inf;
    if (L <= l && r <= R) {
        const auto it = s[o].lower_bound(v);
        // With v <= -inf nothing precedes lower_bound(v); stepping back from
        // begin() would be undefined behaviour.
        return it == s[o].begin() ? -inf : *std::prev(it);
    }
    const int mid = (l + r) >> 1;
    int ans = -inf;
    if (L <= mid) ans = std::max(ans, query(o << 1, l, mid, L, R, v));
    if (R > mid) ans = std::max(ans, query(o << 1 | 1, mid + 1, r, L, R, v));
    return ans;
}

int main() {
    if (std::scanf("%d%d", &n, &m) != 2) return 0;
    if (n < 1) return 0;  // nothing to index
    if (n >= N) {
        std::fprintf(stderr, "n = %d exceeds the supported maximum %d\n", n, N - 1);
        return 1;
    }
    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", A + i) != 1) return 0;
    }
    build(1, 1, n);

    while (m-- > 0) {
        int op, a, b, x;
        if (std::scanf("%d", &op) != 1) break;
        if (op == 1) {
            if (std::scanf("%d%d", &a, &b) != 2) break;
            if (a < 1 || a > n) continue;  // position outside the array: ignore
            updt(1, 1, n, a, b);
            A[a] = b;  // must happen after updt, which reads the old value
        } else {
            if (std::scanf("%d%d%d", &a, &b, &x) != 3) break;
            std::printf("%d\n", query(1, 1, n, a, b, x));
        }
    }
    return 0;
}
線段樹套平衡樹
有很多棵樹的時候,開一個 root 陣列即可:這樣不需要傳引用,因為 splay 的時候會直接更新 root 陣列。
rotate 裡各條賦值語句的順序不可以任意調換:必須先把 tr[x].s[k ^ 1] 接到 y 下面,再把它改寫成 y,否則原來的子樹會丟失。
/*
 * @Author: zhl
 * @Date: 2020-11-16 13:51:18
 *
 * Segment tree over positions 1..n in which every node owns a splay tree
 * holding all values of the range it covers. All splay trees share one node
 * pool (tr); root[o] is the splay root that belongs to segment-tree node o.
 *
 * Operations read by main():
 *   1 l r k   print 1 + (number of values in w[l..r] that are < k)
 *   2 l r k   print the k-th smallest value of w[l..r]
 *   3 pos k   assign w[pos] = k
 *   4 l r k   print the largest value < k in w[l..r], or -inf if none
 *   5 l r k   print the smallest value > k in w[l..r], or inf if none
 */
#include <algorithm>
#include <cstdio>
#include <cstdlib>

const int N = 2000000 + 10, inf = 0x7fffffff;

// One splay node: children s[0]/s[1], subtree size, parent p, key v.
// Index 0 is the null node; its size stays 0.
struct node {
    int s[2], size, p, v;
    void init(int _p, int _v) {
        p = _p;
        v = _v;
        size = 1;
    }
} tr[N];

int w[N], n, m, root[N], idx;  // idx: number of pool nodes handed out so far

// Recomputes the subtree size of u from its two children.
void push_up(int u) {
    tr[u].size = tr[tr[u].s[0]].size + tr[tr[u].s[1]].size + 1;
}

// Rotates x above its parent y, keeping the in-order sequence unchanged.
void rotate(int x) {
    const int y = tr[x].p, z = tr[y].p;
    const int k = tr[y].s[1] == x;  // which child of y is x
    if (z) tr[z].s[tr[z].s[1] == y] = x;  // never write into the null node
    tr[x].p = z;
    // The inner subtree of x is read before x's child slot is overwritten;
    // doing these assignments in the other order would lose that subtree.
    const int inner = tr[x].s[k ^ 1];
    tr[y].s[k] = inner;
    if (inner) tr[inner].p = y;
    tr[x].s[k ^ 1] = y;
    tr[y].p = x;
    push_up(y), push_up(x);
}

// Rotates x upwards until its parent is k. With k == 0, x becomes the root
// of the splay tree of segment-tree node rt and root[rt] is updated.
void splay(int x, int k, int rt) {
    while (tr[x].p != k) {
        const int y = tr[x].p, z = tr[y].p;
        if (z != k) {
            if ((tr[z].s[0] == y) ^ (tr[y].s[0] == x)) rotate(x);  // zig-zag
            else rotate(y);                                        // zig-zig
        }
        rotate(x);
    }
    if (!k) root[rt] = x;
}

/*
 * Inserts key v into the splay tree of segment-tree node rt and splays the
 * new node to the root. Keys equal to an existing key descend to the left.
 * reuse: optional index of a detached node to recycle; 0 (the default)
 *        takes a fresh node from the pool.
 */
void insert(int v, int rt, int reuse = 0) {
    int u = root[rt], p = 0;
    while (u) p = u, u = tr[u].s[v > tr[u].v];
    if (reuse) {
        u = reuse;
    } else {
        if (idx + 1 >= N) {
            std::fprintf(stderr, "splay node pool exhausted (%d nodes)\n", N);
            std::exit(1);
        }
        u = ++idx;
    }
    tr[u].s[0] = tr[u].s[1] = 0;
    if (p) tr[p].s[v > tr[p].v] = u;
    tr[u].init(p, v);
    splay(u, 0, rt);
}

// Number of keys strictly smaller than v in tree rt (sentinels included).
int get_rank(int v, int rt) {
    int u = root[rt], res = 0;
    while (u) {
        if (v > tr[u].v) res += tr[tr[u].s[0]].size + 1, u = tr[u].s[1];
        else u = tr[u].s[0];
    }
    return res;
}

/*
 * Builds the splay tree of node o (covering [l, r]) and of all descendants.
 * Every tree also holds the sentinels -inf and +inf. An empty range is
 * ignored so that n == 0 cannot recurse forever.
 */
void build(int o, int l, int r) {
    if (l > r) return;
    insert(-inf, o);
    insert(inf, o);
    for (int i = l; i <= r; i++) insert(w[i], o);
    if (l == r) return;
    const int mid = (l + r) >> 1;
    build(o << 1, l, mid);
    build(o << 1 | 1, mid + 1, r);
}

// Number of values strictly smaller than x in w[L..R].
int query_rank(int o, int l, int r, int L, int R, int x) {
    if (R < l || r < L) return 0;  // disjoint; also stops malformed ranges
    if (L <= l && r <= R) return get_rank(x, o) - 1;  // drop the -inf sentinel
    const int mid = (l + r) >> 1;
    int ans = 0;
    if (L <= mid) ans += query_rank(o << 1, l, mid, L, R, x);
    if (R > mid) ans += query_rank(o << 1 | 1, mid + 1, r, L, R, x);
    return ans;
}

/*
 * Replaces the old value w[pos] by v in every node on the root-to-leaf path
 * of position pos. The caller must store v into w[pos] afterwards.
 * NOTE(review): assumes w[pos] never equals a sentinel (+-inf) - confirm.
 */
void updt(int o, int l, int r, int pos, int v) {
    // Locate one node carrying the old value.
    int u = root[o];
    while (u && tr[u].v != w[pos]) u = tr[u].s[w[pos] > tr[u].v];
    if (!u) return;  // old value missing: do not touch the tree

    // Bring u to the root, then find its in-order neighbours.
    splay(u, 0, o);
    int ls = tr[u].s[0], rs = tr[u].s[1];
    while (tr[ls].s[1]) ls = tr[ls].s[1];
    while (tr[rs].s[0]) rs = tr[rs].s[0];

    // With ls at the root and rs directly below it, u is the only node
    // between them, i.e. the left child of rs; cut it off.
    splay(ls, 0, o);
    splay(rs, ls, o);
    tr[rs].s[0] = 0;
    push_up(rs);
    push_up(ls);

    // Recycle the detached node for the new value so updates do not
    // consume the fixed-size pool.
    insert(v, o, u);

    if (l == r) return;  // leaf reached: stop descending
    const int mid = (l + r) >> 1;
    if (pos <= mid) updt(o << 1, l, mid, pos, v);
    else updt(o << 1 | 1, mid + 1, r, pos, v);
}

// Largest key strictly smaller than x in tree rt; -inf if there is none.
int get_pre(int x, int rt) {
    int u = root[rt], res = -inf;
    while (u) {
        if (tr[u].v >= x) u = tr[u].s[0];
        else res = tr[u].v, u = tr[u].s[1];
    }
    return res;
}

// Smallest key strictly greater than x in tree rt; +inf if there is none.
int get_suc(int x, int rt) {
    int u = root[rt], res = inf;  // +inf is the neutral element for min()
    while (u) {
        if (tr[u].v <= x) u = tr[u].s[1];
        else res = tr[u].v, u = tr[u].s[0];
    }
    return res;
}

// Largest value strictly smaller than x in w[L..R]; -inf if there is none.
int query_pre(int o, int l, int r, int L, int R, int x) {
    if (R < l || r < L) return -inf;
    if (L <= l && r <= R) return get_pre(x, o);
    const int mid = (l + r) >> 1;
    int ans = -inf;
    if (L <= mid) ans = std::max(ans, query_pre(o << 1, l, mid, L, R, x));
    if (R > mid) ans = std::max(ans, query_pre(o << 1 | 1, mid + 1, r, L, R, x));
    return ans;
}

// Smallest value strictly greater than x in w[L..R]; +inf if there is none.
int query_suc(int o, int l, int r, int L, int R, int x) {
    if (R < l || r < L) return inf;
    if (L <= l && r <= R) return get_suc(x, o);
    const int mid = (l + r) >> 1;
    int ans = inf;
    if (L <= mid) ans = std::min(ans, query_suc(o << 1, l, mid, L, R, x));
    if (R > mid) ans = std::min(ans, query_suc(o << 1 | 1, mid + 1, r, L, R, x));
    return ans;
}

int main() {
    if (std::scanf("%d%d", &n, &m) != 2) return 0;
    if (n < 1) return 0;  // nothing to index
    if (4LL * n >= N) {   // root[] is indexed by segment-tree node ids up to 4n
        std::fprintf(stderr, "n = %d is too large for the static arrays\n", n);
        return 1;
    }
    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", w + i) != 1) return 0;
    }
    build(1, 1, n);

    while (m-- > 0) {
        int op, a, b, k, pos;
        if (std::scanf("%d", &op) != 1) break;
        if (op == 1) {
            if (std::scanf("%d%d%d", &a, &b, &k) != 3) break;
            std::printf("%d\n", query_rank(1, 1, n, a, b, k) + 1);
        } else if (op == 2) {
            if (std::scanf("%d%d%d", &a, &b, &k) != 3) break;
            // Binary search on the answer: the k-th smallest value is the
            // largest candidate whose rank is still <= k.
            // NOTE(review): assumes all values lie in [0, 1e8] - confirm
            // against the problem constraints.
            int low = 0, high = 100000000;
            while (low < high) {
                const int guess = (low + high + 1) >> 1;
                if (query_rank(1, 1, n, a, b, guess) + 1 <= k) low = guess;
                else high = guess - 1;
            }
            std::printf("%d\n", high);
        } else if (op == 3) {
            if (std::scanf("%d%d", &pos, &k) != 2) break;
            if (pos < 1 || pos > n) continue;  // position outside the array: ignore
            updt(1, 1, n, pos, k);
            w[pos] = k;  // must happen after updt, which reads the old value
        } else if (op == 4) {
            if (std::scanf("%d%d%d", &a, &b, &k) != 3) break;
            std::printf("%d\n", query_pre(1, 1, n, a, b, k));
        } else if (op == 5) {
            if (std::scanf("%d%d%d", &a, &b, &k) != 3) break;
            std::printf("%d\n", query_suc(1, 1, n, a, b, k));
        }
    }
    return 0;
}