學習筆記:Splay
阿新 • • 發佈:2020-09-11
程式碼適中、非常靈活的平衡樹。
需要前置:二叉搜尋樹。
一些基礎的函式:
int idx, ch[N][2], cnt[N], sz[N], fa[N]; /* idx 是節點計數, ch[i][0 / 1] 是 i 節點的左右子樹節點 cnt[i] 是 i 節點的數量 sz[i] 是 i 節點子樹的大小 fa[i] 是 i 的父親 */ // pushup void inline pushup(int p) { sz[p] = sz[ch[p][0]] + cnt[p] + sz[ch[p][1]]; } // 判斷 p 是 fa[p] 左兒子還是右兒子 (0 / 1) bool inline get(int p) { return p == ch[fa[p]][1]; } // 清空一個節點 void inline clear(int p) { ch[p][0] = ch[p][1] = val[p] = cnt[p] = sz[p] = fa[p] = 0; }
\(\text{Pushup}\) 要放在旋轉的最後。
\(\text{Pushdown}\) 只要遞迴就推下去。
旋轉的意義:保持中序遍歷不變,調整樹高。
這樣旋轉後,在改變樹形結構的基礎上發現中序遍歷保持不變。
void inline rotate(int x) { int y = fa[x], z = fa[y], k = get(x); ch[y][k] = ch[x][!k], fa[ch[x][!k]] = y; ch[x][!k] = y, fa[y] = x; fa[x] = z; if (z) ch[z][y == ch[z][1]] = x; pushup(y); pushup(x); }
以下所有介紹的操作都是 Splay 的獨特的操作,剩下的二叉搜尋樹就有了。
複雜度的保持 & 核心思想:
每次操作完的點,均將這個點旋轉(Splay)到樹根。
感性理解的好處:每一次用到,後面還有可能再用到。
有嚴謹的證明,結論是若操作 \(m\) 次,總複雜度是 \(O(m \log n)\),平均意義每次操作都是 \(O(\log)\) 的。
Splay 翻轉
定義函式 \(splay(x, k)\) 表示將點 \(x\) 旋轉至 \(k\) 下面。
\(y = fa_x, z = fa_y\)。
迭代:
- 如果 \(z\) 不存在,轉一次 \(x\) 即可。
- 若 \(z, y, x\)
- 否則是折線,就轉兩次 \(x\)
只有這麼轉複雜度才是對的,不能隨便轉,要背一下)
void inline splay(int p) {
for (int f = fa[p]; f = fa[p]; rotate(p))
if (fa[f]) rotate(get(p) == get(f) ? f : p);
rt = p;
}
以下標為鍵:將一段序列插入到 y 的後面
- 找到 \(y\) 的後繼 \(z\)
- 將 \(y\) 旋轉到根 \(splay(y, 0)\)
- 將 \(z\) 轉到 \(y\) 的下面 \(splay(z, y)\)
這樣 \(z\) 一定沒有左子樹,直接把一段序列構造好的樹節點賦值成 \(z\) 的左子樹就行了。
以下標為鍵:操作一段
刪除序列的 \([l, r]\)
\(splay(kth(l - 1), 0), splay(kth(r+1), l - 1)\),這樣 \([l, r]\) 之間所有的點組成了以 \(r + 1\) 的左子樹,這樣直接就可以在 \(kth(r + 1)\) 的左兒子這個節點打 \(tag\) 就行了。
板子
#include <cstdio>
#include <iostream>
using namespace std;
const int N = 100005;
int n, m, rt;
int idx, ch[N][2], val[N], cnt[N], sz[N], fa[N];
void inline update(int p) {
sz[p] = sz[ch[p][0]] + cnt[p] + sz[ch[p][1]];
}
bool inline get(int p) {
return p == ch[fa[p]][1];
}
void inline clear(int p) {
ch[p][0] = ch[p][1] = val[p] = cnt[p] = sz[p] = fa[p] = 0;
}
void inline rotate(int x) {
int y = fa[x], z = fa[y], k = get(x);
ch[y][k] = ch[x][!k], fa[ch[x][!k]] = y;
ch[x][!k] = y, fa[y] = x;
fa[x] = z;
if (z) ch[z][y == ch[z][1]] = x;
update(y); update(x);
}
void inline splay(int p) {
for (int f = fa[p]; f = fa[p]; rotate(p))
if (fa[f]) rotate(get(p) == get(f) ? f : p);
rt = p;
}
void insert(int &p, int x, int f) {
if (!p) {
p = ++idx, sz[p] = cnt[p] = 1, fa[p] = f, val[p] = x;
if (f) ch[f][x > val[f]] = p, update(f), splay(p);
} else if (val[p] == x) cnt[p]++, sz[p]++, update(f), splay(p);
else insert(ch[p][x > val[p]], x, p);
}
int kth(int p, int k) {
if (k <= sz[ch[p][0]]) return kth(ch[p][0], k);
else if (k <= sz[ch[p][0]] + cnt[p]) { splay(p); return val[p]; }
else return kth(ch[p][1], k - sz[ch[p][0]] - cnt[p]);
}
int getRank(int p, int k) {
int res = 0;
if (k < val[p]) return getRank(ch[p][0], k);
else if (k == val[p]) { res = sz[ch[p][0]] + 1; splay(p); return res; }
else { res += sz[ch[p][0]] + cnt[p]; return res + getRank(ch[p][1], k); }
}
int inline pre() {
int p = ch[rt][0];
while (ch[p][1]) p = ch[p][1];
splay(p);
return p;
}
int inline nxt() {
int p = ch[rt][1];
while (ch[p][0]) p = ch[p][0];
splay(p);
return p;
}
void inline del(int k) {
getRank(rt, k);
if (cnt[rt] > 1) cnt[rt]--, sz[rt]--;
else if (!ch[rt][0] && !ch[rt][1]) {
clear(rt), rt = 0;
} else if (!ch[rt][0]) fa[rt = ch[rt][1]] = 0;
else if (!ch[rt][1]) fa[rt = ch[rt][0]] = 0;
else {
int p = rt, x = pre();
splay(x); ch[x][1] = ch[p][1], fa[ch[x][1]] = x;
clear(p); update(rt);
}
}
int main() {
scanf("%d", &m);
while (m--) {
int opt, x; scanf("%d%d", &opt, &x);
if (opt == 1) {
insert(rt, x, 0);
} else if (opt == 2) {
del(x);
} else if (opt == 3) {
insert(rt, x, 0);
printf("%d\n", getRank(rt, x));
del(x);
} else if (opt == 4) {
printf("%d\n", kth(rt, x));
} else if (opt == 5) {
insert(rt, x, 0);
printf("%d\n", val[pre()]);
del(x);
} else if (opt == 6) {
insert(rt, x, 0);
printf("%d\n", val[nxt()]);
del(x);
}
}
}
#include <iostream>
#include <cstdio>
#define ls ch[p][0]
#define rs ch[p][1]
#define get(x) x == ch[fa[x]][1]
using namespace std;
const int N = 100005;
int n, m, val[N], ch[N][2], sz[N], fa[N], rev[N], rt, idx;
void inline pushup(int p) {
sz[p] = sz[ls] + sz[rs] + 1;
}
void inline reverse(int p) {
swap(ls, rs), rev[p] ^= 1;
}
void inline pushdown(int p) {
if (rev[p]) {
if (ls) reverse(ls);
if (rs) reverse(rs);
rev[p] = 0;
}
}
void inline rotate(int x) {
int y = fa[x], z = fa[y], k = get(x);
ch[y][k] = ch[x][!k], fa[ch[x][!k]] = y;
ch[x][!k] = y, fa[y] = x;
fa[x] = z;
if (z) ch[z][y == ch[z][1]] = x;
pushup(y), pushup(x);
}
void inline splay(int x, int k) {
for (int f = fa[x]; (f = fa[x]) != k; rotate(x)) {
if (fa[f]) rotate(get(x) == get(f) ? f : x);
}
if (!k) rt = x;
}
void build(int &p, int l, int r, int f) {
if (l > r) return;
p = ++idx;
int mid = (l + r) >> 1; val[p] = mid, fa[p] = f;
if (l < r) {
build(ch[p][0], l, mid - 1, p);
build(ch[p][1], mid + 1, r, p);
}
pushup(p);
}
void print(int p) {
if (!p) return;
pushdown(p);
print(ch[p][0]);
if (val[p] && val[p] <= n) printf("%d ", val[p]);
print(ch[p][1]);
}
int inline kth(int p, int k) {
pushdown(p);
if (k <= sz[ch[p][0]]) return kth(ch[p][0], k);
else if (k == sz[ch[p][0]] + 1) {
splay(p, 0);
return p;
} else return kth(ch[p][1], k - sz[ch[p][0]] - 1);
}
int main() {
scanf("%d%d", &n, &m);
build(rt, 0, n + 1, 0);
while (m--) {
int l, r; scanf("%d%d", &l, &r);
int x = kth(rt, l), y = kth(rt, r + 2);
splay(x, 0); splay(y, x);
reverse(ch[y][0]);
}
print(rt);
return 0;
}