快速構建 Spring Boot 應用
阿新 • • 發佈:2020-11-20
核心:每操作一個節點,將該節點旋轉到樹根
通過這樣的操作,可以使得平均每個操作的複雜度都是\(O(logn)\)
Splay如何將某個點旋轉到樹根?
定義一種操作:\(Splay(x, k)\) : 將點\(x\)旋轉至點\(k\)下面(k的取值一般是0或1)
旋轉分兩種,如果是一條直線,先轉\(y\),再轉\(x\);如果是一條折線,轉兩次\(x\).
操作
1 插入
(1) 插入一個\(x\),並旋轉到根
\(Splay(x, 0)\)
(2) 把一個序列插入到\(y\)的後面
找到\(y\)的後繼\(z\),將\(y\)轉到根,即\(Splay(y, 0)\)),再將\(z\)轉到\(y\)
2 刪除
(1) 刪除一段\(L\) - \(R\)
找到\(L\)的前驅\(L - 1\),找到\(R\)的後繼\(R + 1\),將\(L - 1\)轉到根節點,將\(R + 1\)轉到根節點下面,\(L\) - \(R\)即為\(R + 1\)的左子樹,刪掉即可.
Splay如何維護資訊?
(1)找第\(k\)個數,維護每個子樹的節點個數
(2)lazy標記,表示該區間是否翻轉
在旋轉之後,pushup維護資訊;在遞迴之前,pushdown下傳lazy標記
注意:Splay 保證中序遍歷是當前序列的順序,且get_k找到的是中序遍歷的第\(k\)
模板題 Acwing 2437
#include <bits/stdc++.h> using namespace std; const int N = 1e5 + 20; int n, m; struct Node { int s[2], p, v; int size, flag; void init(int _v, int _p) { v = _v, p = _p; size = 1; } }tr[N]; int root, idx; void pushup(int x) { tr[x].size = tr[tr[x].s[0]].size + tr[tr[x].s[1]].size + 1; } void pushdown(int x) { if(tr[x].flag) { swap(tr[x].s[0], tr[x].s[1]); tr[tr[x].s[0]].flag ^= 1; tr[tr[x].s[1]].flag ^= 1; tr[x].flag = 0; } } void rotate(int x) { int y = tr[x].p, z = tr[y].p; int k = tr[y].s[1] == x; tr[z].s[tr[z].s[1] == y] = x, tr[x].p = z; tr[y].s[k] = tr[x].s[k ^ 1], tr[tr[x].s[k ^ 1]].p = y; tr[x].s[k ^ 1] = y, tr[y].p = x; pushup(y), pushup(x); } void splay(int x, int k) { while(tr[x].p != k) { int y = tr[x].p, z = tr[y].p; if(z != k) if((tr[y].s[1] == x) ^ (tr[z].s[1] == y)) rotate(x); else rotate(y); rotate(x); } if(!k) root = x; } void insert(int v) { int u = root, p = 0; while(u) p = u, u = tr[u].s[v > tr[u].v]; u = ++ idx; if(p) tr[p].s[v > tr[p].v] = u; tr[u].init(v, p); splay(u, 0); } int get_k(int k) { int u = root; while(1) { pushdown(u); if(tr[tr[u].s[0]].size >= k) u = tr[u].s[0]; else if(tr[tr[u].s[0]].size + 1 == k) return u; else k -= tr[tr[u].s[0]].size + 1, u = tr[u].s[1]; } return -1; } void output(int u) { pushdown(u); if(tr[u].s[0]) output(tr[u].s[0]); if(tr[u].v >= 1 && tr[u].v <= n) printf("%d ",tr[u].v); if(tr[u].s[1]) output(tr[u].s[1]); } int main() { scanf("%d%d", &n, &m); //0和n+1用於防止越界 for(int i = 0; i <= n + 1; ++ i) insert(i); while(m -- ) { int l, r; scanf("%d%d", &l, &r); l = get_k(l), r = get_k(r + 2); splay(l, 0), splay(r, l); tr[tr[r].s[0]].flag ^= 1; } output(root); return 0; }