Splay 的錯誤記錄和模板
阿新 • 發佈:2020-10-08
連結:Splay 的簡介和相關題目
錯誤記錄(除錯用)
inline int find(int k) { int p = root; while (1) { pushdown(p);//Attention!! if (k <= siz[son[p][0]]) { //WA : p = siz[son[p][0]]; continue; p = son[p][0]; continue; } int tmp = siz[son[p][0]] + 1; if (k <= tmp) break; k -= tmp; p = son[p][1]; } return p; }
// Error record: after split(x, y) puts x at the root and y as the root's right child,
// the queried range is the LEFT subtree of the root's RIGHT child. The commented-out
// line (left child's right subtree) was the wrong answer; the line after it is the fix.
//printf("%d\n", sum[son[son[root][0]][1]]);
printf("%d\n", sum[son[son[root][1]][0]]);
inline void Swap(int x, int y) {//y is next to x(..., x, y, ...) splay(x, 0); splay(y, x); int rt1 = son[x][0], rt2 = son[y][1]; son[x][0] = 0, son[y][1] = x; son[x][1] = rt2, son[y][0] = rt1; fa[x] = y, fa[y] = 0; fa[rt1] = y, fa[rt2] = x; pushup(x); pushup(y); root = y;//Attention! }
// Build a balanced subtree over num[L..R] and return its root (0 for an empty range).
// In this variant the node id is the stored value itself: node = num[mid].
int build(int L, int R, int faa) {
    int cur = 0;
    if (L <= R) {
        const int mid = (L + R) >> 1;
        cur = num[mid];
        fa[cur] = faa;  // Attention!! the parent link has to be set here
        son[cur][0] = build(L, mid - 1, cur);
        son[cur][1] = build(mid + 1, R, cur);
        pushup(cur);
    }
    return cur;
}
void print(int cur) { if (!cur) return ; pushdown(cur);//Attention! print(son[cur][0]); if (cur <= n) printf("%d\n", cur);//注意順序:先左,再中(本身),再右 print(son[cur][1]); }
// Detach the subtree rooted at cur from its parent.
inline void Cut(int cur) {
    const int faa = fa[cur];
    // NOTE(review): debug trace written to stdout; it would mix with judged output.
    printf("Cut : faa = %d\n", faa);
    // Must happen before fa[cur] is cleared: get_which() reads fa[cur].
    son[faa][get_which(cur)] = 0;
    fa[cur] = 0;
    // Deliberately NOT doing pushup(faa) / splay(faa, 0) here: according to the
    // original note, that splay conflicts with a statement in main().
}
// Locate the node containing the k-th item of tree rt[nwid], split that node
// at the item, and return it.
inline int find(int k) {
    // Do NOT pre-increment k here: in this variant the sentinel nodes have siz == 0.
    int p = rt[nwid];  // several Splay trees are maintained; nwid selects the current one
    for (;;) {
        const int lsz = siz[son[p][0]];
        if (k <= lsz) {
            p = son[p][0];
            continue;
        }
        k -= lsz;
        // presumably node p stands for the consecutive run l[p] .. r[p] — confirm
        const int len = r[p] - l[p] + 1;
        if (k <= len) {
            split(p, l[p] + k - 1);  // splits the node itself, not a range extraction
            break;
        }
        k -= len;
        p = son[p][1];
    }
    return p;
}
// Splay cur upward until its parent is goal; goal == 0 makes cur the root of tree nwid.
// NOTE(review): no pushdown along the path — presumably this variant has no lazy tags; confirm.
inline void splay(int cur, int goal) {
    for (register int faa = fa[cur]; faa != goal; rotate(cur), faa = fa[cur])
        // double rotation: rotate the parent first when cur and faa are on the same side
        if (fa[faa] != goal) rotate(get_which(cur) == get_which(faa) ? faa : cur);
        //"fa[faa] != goal" !!! not "fa[faa]"!!  (testing fa[faa] alone breaks whenever goal != 0)
    pushup(cur);  // rotate() only refreshes the demoted node, so cur is refreshed here
    if (!goal) rt[nwid] = cur;
}
// Return the in-order predecessor of cur: the rightmost node of cur's left subtree
// (0 when cur has no left subtree). Splays cur to the root first.
inline int get_pre(int cur) {
    splay(cur, 0);  // Attention!! cur must be the root first (alternatively do this in main)
    int p;
    for (p = son[cur][0]; son[p][1]; p = son[p][1]) {
    }
    return p;
}
// Rotate cur above its parent faa while preserving in-order order.
// Statement order matters: the side flag and faa's slot under fafa are read before links change.
inline void rotate(int cur) {
    int faa = fa[cur], fafa = fa[faa];
    bool flag = get_which(cur);//Bug  (recorded error site; 1 means cur is a right child)
    fa[cur] = fafa; if (fafa) son[fafa][get_which(faa)] = cur;  // cur takes faa's slot under fafa
    son[faa][flag] = son[cur][flag ^ 1]; if (son[cur][flag ^ 1]) fa[son[cur][flag ^ 1]] = faa;//bug  (recorded error site; the guard avoids writing fa[0])
    son[cur][flag ^ 1] = faa; fa[faa] = cur;  // faa becomes cur's child on the opposite side
    pushup(faa);  // only faa is refreshed; the caller refreshes cur
}
模板(以區間翻轉,區間推平,區間求和,區間查詢最大子段和,單點查詢,插入一段序列,刪除一段序列為例,含回收廢點)(除錯用)
int n, m;  // initial sequence length and number of operations (both read in main)
// Tree structure: son[x][0/1] = children, fa[x] = parent, root = current root,
// ttot = largest node id handed out so far, val[x] = element stored at x. N is defined elsewhere.
int son[N][2], fa[N], root, ttot, val[N];
// Subtree aggregates refreshed by pushup(): sum, best prefix (lmx, may be empty),
// best suffix (rmx, may be empty), best non-empty segment (mx), node count (siz).
int sum[N], lmx[N], rmx[N], mx[N], siz[N];
// Lazy tags. tag_sam[x]: x's subtree was assigned val[x], not yet pushed to the children.
// tag_rev[x]: x's own children are already swapped; the reversal still has to be pushed into them.
bool tag_sam[N], tag_rev[N];
// Stack of recycled node ids: clear_tre() pushes, get_new() pops.
int bin[N], btot;
// Hand out a clean node id, preferring a recycled one from bin[].
inline int get_new() {
    int cur;
    if (btot != 0) cur = bin[btot--];  // reuse a freed id
    else cur = ++ttot;                 // otherwise take a brand-new id
    // Wipe whatever a previous owner of this id left behind.
    fa[cur] = 0;
    son[cur][0] = 0;
    son[cur][1] = 0;
    val[cur] = sum[cur] = 0;
    siz[cur] = mx[cur] = lmx[cur] = rmx[cur] = 0;
    tag_rev[cur] = tag_sam[cur] = false;
    return cur;
}
// True when cur is the right child of its parent, false when it is the left child.
inline bool get_which(int cur) {
    const int parent = fa[cur];
    return cur == son[parent][1];
}
// Recompute cur's subtree aggregates from its children.
//   mx[]          : best sum of a NON-EMPTY contiguous segment
//   lmx[] / rmx[] : best prefix / suffix sum, which MAY be empty (value 0)
inline void pushup(int cur) {
    const int ls = son[cur][0];
    const int rs = son[cur][1];
    const int v = val[cur];
    siz[cur] = siz[ls] + siz[rs] + 1;
    sum[cur] = v + sum[ls] + sum[rs];
    // best segment that crosses cur: a suffix of the left side + cur + a prefix of the right side
    const int cross = rmx[ls] + v + lmx[rs];
    mx[cur] = max(cross, max(mx[ls], mx[rs]));
    lmx[cur] = max(lmx[ls], sum[ls] + v + lmx[rs]);
    rmx[cur] = max(rmx[rs], sum[rs] + v + rmx[ls]);
}
// Reverse cur's subtree lazily: swap its children and its prefix/suffix maxima now,
// and leave a tag so that the children get reversed on the next pushdown.
inline void push_rev(int cur) {
    if (cur == 0) return;  // the empty subtree is never tagged
    tag_rev[cur] = !tag_rev[cur];
    int t = son[cur][0];
    son[cur][0] = son[cur][1];
    son[cur][1] = t;
    t = lmx[cur];
    lmx[cur] = rmx[cur];
    rmx[cur] = t;
}
// Assign vall to every element of cur's subtree lazily and fix cur's aggregates.
inline void push_sam(int cur, int vall) {
    if (cur == 0) return;  // the empty subtree is never tagged
    tag_sam[cur] = true;
    val[cur] = vall;
    sum[cur] = siz[cur] * vall;
    if (vall > 0) {
        // Positive value: taking the whole segment is best everywhere.
        lmx[cur] = rmx[cur] = mx[cur] = sum[cur];
    } else {
        // Non-positive: prefix/suffix prefer empty; a non-empty segment is one element.
        mx[cur] = vall;
        lmx[cur] = rmx[cur] = 0;
    }
}
inline void pushdown(int cur) {
if (tag_sam[cur])
push_sam(son[cur][0], val[cur]), push_sam(son[cur][1], val[cur]), tag_sam[cur] = 0;
if (tag_rev[cur])
push_rev(son[cur][0]), push_rev(son[cur][1]), tag_rev[cur] = 0;
}
// Rotate cur above its parent faa while preserving in-order order.
// Statement order matters: the side flag and faa's slot under fafa are read before links change.
inline void rotate(int cur) {
    int faa = fa[cur], fafa = fa[faa];
    bool flag = get_which(cur);  // 1 if cur is a right child
    fa[cur] = fafa; if (fafa) son[fafa][get_which(faa)] = cur;  // cur takes faa's slot under fafa
    son[faa][flag] = son[cur][flag ^ 1]; if (son[cur][flag ^ 1]) fa[son[cur][flag ^ 1]] = faa;  // cur's inner subtree moves to faa; never writes fa[0]
    son[cur][flag ^ 1] = faa; fa[faa] = cur;  // faa becomes cur's child on the opposite side
    pushup(faa);  // only faa is refreshed here; splay() refreshes cur at the end
}
int stk[N], stop;  // scratch stack used by splay() to hold the path from cur up to goal
// Splay cur upward until its parent is goal (goal == 0 makes cur the root).
// All lazy tags on the path are pushed down, top first, before any rotation.
inline void splay(int cur, int goal) {
    int p = cur;
    while (p != goal) stk[++stop] = p, p = fa[p];  // collect cur .. the child of goal
    if (goal) pushdown(goal);
    while (stop) pushdown(stk[stop--]);  // pop from the top of the path downward
    for (register int faa = fa[cur]; faa != goal; rotate(cur), faa = fa[cur])
        // double rotation: rotate the parent first when cur and faa are on the same side
        if (fa[faa] != goal) rotate(get_which(cur) == get_which(faa) ? faa : cur);
    pushup(cur);  // rotate() only refreshes the demoted node, so cur is refreshed here
    if (goal == 0) root = cur;
}
// Return the node of the k-th element (1-based); position 0 is the left sentinel.
inline int find(int k) {
    int pos = k + 1;  // shift past the left sentinel, which occupies in-order position 1
    int node = root;
    for (;;) {
        pushdown(node);  // children must be up to date before their sizes are read
        const int lsz = siz[son[node][0]];
        if (pos <= lsz) {
            node = son[node][0];
            continue;
        }
        if (pos == lsz + 1) return node;
        pos -= lsz + 1;
        node = son[node][1];
    }
}
// Bring x to the root and then y directly under x. When x precedes y in order,
// the elements strictly between them end up as son[y][0].
inline void split(int x, int y) {
    splay(x, 0);  // step 1: x becomes the root
    splay(y, x);  // step 2: y becomes a child of x
}
int num[N];  // staging buffer: the values build() turns into a subtree
// Build a balanced subtree over num[L..R] with fresh nodes whose parent is faa.
// Returns its root, or 0 for an empty range.
int build(int L, int R, int faa) {
    if (R < L) return 0;
    const int mid = (L + R) / 2;
    const int node = get_new();
    fa[node] = faa;
    val[node] = num[mid];
    sum[node] = mx[node] = val[node];                        // mx: segment must be non-empty
    lmx[node] = rmx[node] = val[node] > 0 ? val[node] : 0;   // prefix/suffix may be empty
    son[node][0] = build(L, mid - 1, node);
    son[node][1] = build(mid + 1, R, node);
    pushup(node);
    return node;
}
// Recycle every node of the subtree rooted at cur by pushing its ids onto bin[]
// (get_new() pops from there). Iterative on purpose: a splay tree can degenerate
// into a long path, and a recursive walk over such a subtree can overflow the call
// stack. The freshly pushed part of bin[] doubles as the work queue, so no extra
// memory is needed and exactly the same set of ids is recycled as before.
inline void clear_tre(int cur) {
    if (!cur) return;
    int next = btot;  // bin[next + 1 .. btot] are recycled but their children not yet queued
    bin[++btot] = cur;
    while (next < btot) {
        const int p = bin[++next];
        if (son[p][0]) bin[++btot] = son[p][0];
        if (son[p][1]) bin[++btot] = son[p][1];
    }
}
// Isolate the nn elements that start at 1-based position x. split() leaves the node
// just before the range at the root and the node just after it as the root's right
// child, so the range is exactly son[son[root][1]][0]; that subtree root is returned
// (0 when nn == 0).
static int locate_range(int x, int nn) {
    const int after = find(x + nn);
    const int before = find(x - 1);
    split(before, after);
    return son[son[root][1]][0];
}

// Refresh the two ancestors of the range subtree after it was tagged or removed.
static void refresh_top() {
    pushup(son[root][1]);
    pushup(root);
}

int main() {
    read(n); read(m);
    // Node 0 is the empty subtree. mx[] describes non-empty segments, so an empty
    // child must never win the max() in pushup().
    mx[0] = -inf;
    // Positions 1 and n + 2 are sentinels, so every real range has a neighbor on each side.
    num[1] = -inf;
    for (int i = 1; i <= n; ++i) read(num[i + 1]);
    num[n + 2] = -inf;
    root = build(1, n + 2, 0);

    char opt[10];  // longest command is "MAKE-SAME": 9 characters + '\0'
    for (int q = 1; q <= m; ++q) {
        // Bounded read: an over-long token can no longer overflow opt, and truncated
        // input stops the loop instead of silently replaying the previous command.
        if (scanf("%9s", opt) != 1) break;
        if (opt[0] == 'G' && opt[3] == '-') {  // GET-SUM x nn
            int x, nn;
            read(x); read(nn);
            if (!nn) {
                printf("0\n");
                continue;
            }
            printf("%d\n", sum[locate_range(x, nn)]);
        } else if (opt[0] == 'G') {  // GET x
            int x;
            read(x);
            printf("%d\n", val[find(x)]);
        } else if (opt[0] == 'M' && opt[2] == 'X') {  // MAX-SUM x nn
            int x, nn;
            read(x); read(nn);
            // NOTE(review): nn == 0 gives an empty range and prints mx[0] (-inf).
            // GET-SUM special-cases nn == 0 but MAX-SUM does not; confirm against the problem.
            printf("%d\n", mx[locate_range(x, nn)]);
        } else if (opt[0] == 'R') {  // REVERSE x nn
            int x, nn;
            read(x); read(nn);
            push_rev(locate_range(x, nn));
            refresh_top();
        } else if (opt[0] == 'M') {  // MAKE-SAME x nn t
            int x, nn, t;
            read(x); read(nn); read(t);
            push_sam(locate_range(x, nn), t);
            refresh_top();
        } else if (opt[0] == 'D') {  // DELETE x nn
            int x, nn;
            read(x); read(nn);
            clear_tre(locate_range(x, nn));
            son[son[root][1]][0] = 0;
            refresh_top();
        } else if (opt[0] == 'I') {  // INSERT x nn v1..vnn : insert after position x
            int x, nn;
            read(x); read(nn);
            for (int j = 1; j <= nn; ++j) read(num[j]);
            const int nxt = find(x + 1);
            const int cur = find(x);
            split(cur, nxt);  // cur at the root, nxt its right child with an empty left subtree
            const int sub = build(1, nn, nxt);
            son[nxt][0] = sub;
            pushup(nxt); pushup(root);
        }
    }
    return 0;
}
Continued...