1. 程式人生 > 其它 >Solution -「PKUWC 2018」「洛谷 P5298」Minimax

Solution -「PKUWC 2018」「洛谷 P5298」Minimax

\(\mathscr{Description}\)

  Link.

  給定一棵二叉樹,每片葉子有一個權值,所有權值互不相同。每個非葉結點 \(u\) 有一個概率 \(p_u\in(0,1)\),表示 \(u\) 的權值以 \(p_u\) 的概率取兒子最大權值,以 \(1-p_u\) 的概率取兒子最小權值。求根節點取到每種權值的概率(以一定形式壓縮輸出)。答案模 \(998244353\)

\(\mathscr{Solution}\)

  令 \(f(u,i)\) 表示 \(u\) 處取到全域性第 \(i\) 大的權值的概率,設 \(u\) 的左右兒子為 \(l,r\),顯然

\[f(u,i)=f(l,i)\left(p_u\sum_{j<i}f(r,j)+(1-p_u)\sum_{j>i}f(r,j)\right)+f(r,i)\left(p_u\sum_{j<i}f(l,j)+(1-p_u)\sum_{j>i}f(l,j)\right). \]

這是個左右區間的交叉貢獻,用權值線段樹維護 \(f(u)\)

,我們居然能直接將 \(f(l)\)\(f(r)\) 的樹“混合”直接求出 \(f(u)\)

  注意,例如左對右有貢獻,直接在右子樹上打乘法或加法之類的標記是不良的——不同時間戳下的標記合併方式不同。因而,我們可以暴力地保證,當且僅當線段樹結點 \(u\) 的後代不會被打上當前時間戳的標記(即,\(u\) 子樹內的貢獻因子都一樣了),我們才給它打乘法標記;否則僅累加係數遞迴傳遞。

  所有貢獻完成後,我們再把 \(f'(l)\)\(f'(r)\) 合併(對應點貢獻相加),就得到了 \(f(u)\)。本質上就是做了兩次線段樹合併,所以複雜度是 \(\mathcal O(n\log n)\)

  那麼這件事情告訴我們線段樹無所不能,所有看似暴力的 DP 都拿來試一試。(

\(\mathscr{Code}\)

/*+Rainybunny+*/

#include <bits/stdc++.h>

#define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i)
#define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i)

const int MAXN = 3e5, MOD = 998244353, INV1E4 = 796898467;
int n, siz[MAXN + 5], ch[MAXN + 5][2], val[MAXN + 5];
int mxv, dc[MAXN + 5], root[MAXN + 5];

inline int mul(const int u, const int v) { return 1ll * u * v % MOD; }
inline void subeq(int& u, const int v) { (u -= v) < 0 && (u += MOD); }
inline int sub(int u, const int v) { return (u -= v) < 0 ? u + MOD : u; }
inline void addeq(int& u, const int v) { (u += v) >= MOD && (u -= MOD); }
inline int add(int u, const int v) { return (u += v) < MOD ? u : u - MOD; }

struct SegmentTree {
    static const int MAXND = 3e6;
    int node, ch[MAXND][2], sum[MAXND], tag[MAXND];

    inline void pushup(const int u) {
        sum[u] = add(sum[ch[u][0]], sum[ch[u][1]]);
    }

    inline void pushml(const int u, const int v) {
        assert(u);
        sum[u] = mul(sum[u], v), tag[u] = mul(tag[u], v);
    }

    inline void pushdn(const int u) {
        if (tag[u] != 1) {
            if (ch[u][0]) pushml(ch[u][0], tag[u]);
            if (ch[u][1]) pushml(ch[u][1], tag[u]);
            tag[u] = 1;
        }
    }

    inline void insert(int& u, const int l, const int r, const int x) {
        u = ++node, tag[u] = sum[u] = 1;
        if (l == r) return ;
        int mid = l + r >> 1;
        if (x <= mid) insert(ch[u][0], l, mid, x);
        else insert(ch[u][1], mid + 1, r, x);
    }

    inline void mix(const int u, const int v,
      const int su, const int sv, const int p) {
        if (!u && !v) return ;
        if (u && !v) return pushml(u, su);
        if (v && !u) return pushml(v, sv);
        if (!ch[u][0] && !ch[u][1]) return pushml(u, su), pushml(v, sv);
        pushdn(u), pushdn(v);
        int ul = sum[ch[u][0]], ur = sum[ch[u][1]];
        int vl = sum[ch[v][0]], vr = sum[ch[v][1]], q = sub(1, p);
        mix(ch[u][0], ch[v][0], add(su, mul(q, vr)), add(sv, mul(q, ur)), p);
        mix(ch[u][1], ch[v][1], add(su, mul(p, vl)), add(sv, mul(p, ul)), p);
        pushup(u), pushup(v);
    }

    inline void merge(int& u, const int v) {
        if (!u || !v) return void(u |= v);
        if (!ch[u][0] && !ch[u][1]) return addeq(sum[u], sum[v]);
        pushdn(u), pushdn(v), addeq(sum[u], sum[v]);
        merge(ch[u][0], ch[v][0]), merge(ch[u][1], ch[v][1]);
    }

    inline int answer(const int u, const int l, const int r) {
        if (l == r) return mul(mul(l, dc[l]), mul(sum[u], sum[u]));
        int mid = l + r >> 1, ret = 0; pushdn(u);
        addeq(ret, answer(ch[u][0], l, mid));
        addeq(ret, answer(ch[u][1], mid + 1, r));
        return ret;
    }
} sgt;

inline void solve(const int u) {
    if (!ch[u][0]) {
        val[u] = std::lower_bound(dc + 1, dc + mxv + 1, val[u]) - dc;
        sgt.insert(root[u], 1, mxv, val[u]);
    } else if (!ch[u][1]) {
        solve(ch[u][0]), root[u] = root[ch[u][0]];
    } else {
        solve(ch[u][0]), solve(ch[u][1]);
        sgt.mix(root[ch[u][0]], root[ch[u][1]], 0, 0, val[u]);
        sgt.merge(root[u] = root[ch[u][0]], root[ch[u][1]]);
    }
}

int main() {
    std::ios::sync_with_stdio(false), std::cin.tie(0);

    std::cin >> n;
    rep (i, 1, n) {
        int fa; std::cin >> fa;
        if (fa) ch[fa][!!ch[fa][0]] = i;
    }
    rep (i, 1, n) {
        std::cin >> val[i];
        if (ch[i][0]) val[i] = mul(val[i], INV1E4);
        else dc[++mxv] = val[i];
    }

    std::sort(dc + 1, dc + mxv + 1);
    mxv = std::unique(dc + 1, dc + mxv + 1) - dc - 1;
    solve(1);

    std::cout << sgt.answer(root[1], 1, mxv) << '\n';
    return 0;
}