luogu P4332 [SHOI2014]三叉神經樹
阿新 • • 發佈:2021-12-17
https://www.luogu.com.cn/problem/P4332
成功複習了LCT
首先發現:翻轉一個葉子時,輸出會改變的節點一定構成從該葉子向上的一條連續路徑。
我們記錄 \(sum[u]\) 表示節點 \(u\) 的三個兒子中有幾個輸出為 \(1\),則 \(val[u]=[sum[u]>1]\)。
這樣維護最大值和最小值,每次 \(access\) 一下然後在 \(Splay\) 上二分,可以做到 \(O(n\log^2 n)\)。
更好的做法是在 \(Splay\) 上記錄子樹中最深的 \(sum\neq 1\) 的點和最深的 \(sum\neq 2\) 的點(分別用於葉子從 \(0\) 變 \(1\) 和從 \(1\) 變 \(0\) 的情況),然後每次只需把從葉子的父親向上到這個點之間的節點整體修改即可。
程式碼實現有些細節要注意
code:
// Luogu P4332 solved with a link-cut tree.
//
// Node layout, as read by main(): internal nodes are 1..n, and each has three
// inputs. Input leaves are n+1..3n+1. An internal node outputs 1 when more
// than one of its inputs is 1. Each query flips one leaf and prints the
// output of node 1.
#include <array>
#include <cstdio>
#include <utility>
#include <vector>

// Static capacity for node indices (must exceed 3n+1).
constexpr int N = 2500050;

// Reads the next non-negative decimal integer from stdin, skipping any
// non-digit characters. Returns 0 if EOF is reached before a digit.
inline int rd() {
    int c = std::getchar();
    while (c != EOF && (c < '0' || c > '9')) c = std::getchar();
    int x = 0;
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = std::getchar();
    }
    return x;
}

// Link-cut tree over the real tree. Each splay tree holds one preferred path,
// ordered shallow (left) to deep (right).
//
//   ch, fa   : splay children / parent (fa of a splay root is its path-parent)
//   sum[x]   : how many of x's inputs are 1
//   val[x]   : output of x, i.e. sum[x] > 1
//   tg[x]    : pending +-1 for sum of every node in x's splay subtree,
//              already applied to x itself
//   id[x][k] : for k = 1, 2, the deepest node in x's splay subtree whose
//              sum != k, or 0 if there is none
// Index 0 is a null sentinel; its fields must stay zero.
struct LCT {
    int ch[N][2], val[N], id[N][3], tg[N], sum[N], fa[N];

    // Whether x is the right child of its splay parent.
    bool get(int x) const { return ch[fa[x]][1] == x; }

    // Whether x is NOT the root of its splay tree.
    bool nrt(int x) const { return ch[fa[x]][0] == x || ch[fa[x]][1] == x; }

    // Recomputes id[x] from x itself and its (already up-to-date) children.
    void update(int x) {
        const int l = ch[x][0], r = ch[x][1];
        for (int k = 1; k <= 2; ++k) {
            if (id[r][k]) id[x][k] = id[r][k];   // deeper nodes are on the right
            else if (sum[x] != k) id[x][k] = x;
            else id[x][k] = id[l][k];
        }
    }

    // Adds o (+1 or -1) to sum of every node in x's splay subtree.
    // Callers only apply this when the subtree is uniform: all sum == 1 for
    // o = +1, or all sum == 2 for o = -1. Afterwards every node holds the
    // other value, so the two "deepest sum != k" answers simply trade places
    // and id[x] stays exact. No update() is needed afterwards.
    void padd(int x, int o) {
        if (!x) return;  // never touch the null sentinel
        tg[x] += o;
        sum[x] += o;
        val[x] = sum[x] > 1;
        std::swap(id[x][1], id[x][2]);
    }

    // Pushes x's pending tag to its splay children.
    void pushdown(int x) {
        if (!tg[x]) return;
        padd(ch[x][0], tg[x]);
        padd(ch[x][1], tg[x]);
        tg[x] = 0;
    }

    // Single splay rotation of x over its splay parent.
    void rotate(int x) {
        const int f = fa[x], gf = fa[f];
        const int k = get(x);
        if (nrt(f)) ch[gf][get(f)] = x;
        fa[x] = gf;
        ch[f][k] = ch[x][!k];
        if (ch[x][!k]) fa[ch[x][!k]] = f;
        ch[x][!k] = f;
        fa[f] = x;
        update(f);
        update(x);
    }

    // Pushes all pending tags from x's splay root down to x. Iterative, because
    // a splay tree can be arbitrarily deep.
    void pushall(int x) {
        static std::vector<int> path;
        path.clear();
        for (int y = x;; y = fa[y]) {
            path.push_back(y);
            if (!nrt(y)) break;
        }
        for (auto it = path.rbegin(); it != path.rend(); ++it) pushdown(*it);
    }

    // Moves x to the root of its splay tree.
    void splay(int x) {
        pushall(x);
        while (nrt(x)) {
            const int f = fa[x];
            if (nrt(f)) rotate(get(f) == get(x) ? f : x);
            rotate(x);
        }
    }

    // Makes the real-tree path from the root to x the preferred path, with x
    // as its deepest node.
    void access(int x) {
        for (int y = 0; x; y = x, x = fa[x]) {
            splay(x);
            ch[x][1] = y;
            update(x);
        }
    }
} T;

int main() {
    const int n = rd();

    // kids[i] holds the three inputs of internal node i. The input lists each
    // node's inputs, so real-tree parents are known directly.
    std::vector<std::array<int, 3>> kids(n + 1);
    for (int i = 1; i <= n; ++i) {
        for (int& c : kids[i]) {
            c = rd();
            T.fa[c] = i;
        }
    }
    for (int i = n + 1; i <= 3 * n + 1; ++i) T.val[i] = rd();

    // Evaluate internal nodes bottom-up. A BFS order from the root, walked
    // backwards, visits children before parents without deep recursion.
    if (n >= 1) {
        std::vector<int> order;
        order.reserve(n);
        order.push_back(1);
        for (std::size_t h = 0; h < order.size(); ++h)
            for (int c : kids[order[h]])
                if (c >= 1 && c <= n) order.push_back(c);
        for (auto it = order.rbegin(); it != order.rend(); ++it) {
            const int u = *it;
            T.sum[u] = 0;
            for (int c : kids[u]) T.sum[u] += T.val[c];
            T.val[u] = T.sum[u] > 1;
            T.update(u);  // seed id[u] for its initial single-node splay tree
        }
    }

    const int m = rd();
    for (int q = 0; q < m; ++q) {
        const int y = rd();     // leaf to flip
        const int x = T.fa[y];  // first node whose sum changes
        const int o = T.val[y] ? -1 : 1;
        if (x) {
            T.access(x);
            T.splay(x);
            // Going up from x, nodes with sum == k flip their output. The
            // deepest node z with sum != k changes its sum but keeps its
            // output, so propagation stops at z.
            const int z = T.id[x][o == 1 ? 1 : 2];
            if (z) {
                T.splay(z);
                T.padd(T.ch[z][1], o);  // nodes strictly between z and x (inclusive of x)
                T.sum[z] += o;
                T.val[z] = T.sum[z] > 1;
                T.update(z);
            } else {
                T.padd(x, o);  // the whole root path, including node 1, flips
            }
            T.splay(1);  // push pending tags down so val[1] is current
        }
        T.val[y] ^= 1;
        std::printf("%d\n", T.val[1]);
    }
    return 0;
}