1. 程式人生 > 其它 >luogu P4332 [SHOI2014]三叉神經樹

luogu P4332 [SHOI2014]三叉神經樹

https://www.luogu.com.cn/problem/P4332

成功複習了LCT

首先發現狀態改變的一定是葉子向上的一條路徑
我們記錄\(sum[u]\)表示\(u\)節點有幾個兒子是\(1\)\(val[u]=sum[u]>1\)

這樣維護最大值和最小值,每次\(access\)一下然後在\(Splay\)上二分可以做到\(nlog^2n\)
考慮記錄下來\(Splay\)上最深的\(sum\)不是\(1/2\)的點,然後每次向上修改到這個點即可

程式碼實現有些細節要注意

code:


#include<bits/stdc++.h>
#define N 2500050
using namespace std;
inline int rd() {
    int x = 0; char ch = getchar();
    for(; ch < '0' || ch > '9' ;) ch = getchar();
    for(; ch >= '0' && ch <= '9'; ) x = (x << 3) + (x << 1) + (ch - '0'), ch = getchar();
    return x;
}
struct LCT {
    #define ls ch[x][0]
    #define rs ch[x][1]
    int ch[N][2], val[N], id[N][3], tg[N], sum[N], fa[N];
    int get(int x) {return ch[fa[x]][1] == x; }
    int nrt(int x) {return ch[fa[x]][0] == x || ch[fa[x]][1] == x;}
    void update(int x) {
        id[x][1] = id[rs][1], id[x][2] = id[rs][2];
        if(!id[x][1]) {
            if(sum[x] != 1) id[x][1] = x;
            else id[x][1] = id[ls][1];
        }
        if(!id[x][2]) {
            if(sum[x] != 2) id[x][2] = x;
            else id[x][2] = id[ls][2];
        }
    }
    void padd(int x, int o) {
        tg[x] += o, sum[x] += o, val[x] = sum[x] > 1;
        swap(id[x][1], id[x][2]);
    }
    void pushdown(int x) {
        if(tg[x]) {
            padd(ls, tg[x]), padd(rs, tg[x]);
            tg[x] = 0; 
        }
    }

    void rotate(int x) {
        int f = fa[x], gf = fa[f], k = get(x);
        if(nrt(f)) ch[gf][get(f)] = x; fa[x] = gf;
        ch[f][k] = ch[x][!k]; if(ch[x][!k]) fa[ch[x][!k]] = f;
        ch[x][!k] = f, fa[f] = x;
        update(f), update(x);
    }
    void pushall(int x) {
        if(nrt(x)) pushall(fa[x]);
        pushdown(x);
    }
    void splay(int x) {
        pushall(x);
        while(nrt(x)) {
            int f = fa[x];
            if(nrt(f)) rotate(get(f) == get(x)? f : x);
            rotate(x);
        }
    }
    void access(int x) {
        for(int y = 0; x; y = x, x = fa[x]) {
            splay(x), rs = y, update(x);
        }
    }
} T;

int n, m;
vector<int> g[N];
void dfs(int u, int fa) {
    T.sum[u] = 0;
    for(int v : g[u]) {
        if(v == fa) continue;
        T.fa[v] = u;
        dfs(v, u);
        T.sum[u] += T.val[v];
    }
    if(u <= n) T.val[u] = T.sum[u] > 1;
}
int main() {
    n = rd();
    for(int i = 1; i <= n; i ++) {
        for(int j = 1; j <= 3; j ++) {
            int x;
            x = rd();
            g[x].push_back(i), g[i].push_back(x);
        }
    }
    for(int i = n + 1; i <= 3 * n + 1; i ++) T.val[i] = rd();

    dfs(1, 1);
    m = rd();
    while(m --) {
        int x, y;
        y = rd(), x = T.fa[y];
        int o = T.val[y]? - 1 : 1;
        T.access(x), T.splay(x);
        int k = o;
        if(k == -1) k = 2;
        int z = T.id[x][k];
        if(z) {
            T.splay(z);
            T.padd(T.ch[z][1], o), T.update(T.ch[z][1]);
            T.sum[z] += o, T.val[z] = T.sum[z] > 1; T.update(z);
        } else T.padd(x, o), T.update(x);
        T.val[y] ^= 1; T.splay(1); 
        printf("%d\n", T.val[1]);
    }
    return 0;
}