Solution -「PKUWC 2018」「洛谷 P5298」Minimax
阿新 • • 發佈:2022-02-23
\(\mathscr{Description}\)
Link.
給定一棵二叉樹,每片葉子有一個權值,所有權值互不相同。每個非葉結點 \(u\) 有一個概率 \(p_u\in(0,1)\),表示 \(u\) 的權值以 \(p_u\) 的概率取兒子最大權值,以 \(1-p_u\) 的概率取兒子最小權值。求根節點取到每種權值的概率(以一定形式壓縮輸出)。答案模 \(998244353\)。
\(\mathscr{Solution}\)
令 \(f(u,i)\) 表示 \(u\) 處取到全域性第 \(i\) 大的權值的概率,設 \(u\) 的左右兒子為 \(l,r\),顯然
\[f(u,i)=f(l,i)\left(p_u\sum_{j<i}f(r,j)+(1-p_u)\sum_{j>i}f(r,j)\right)+f(r,i)\left(p_u\sum_{j<i}f(l,j)+(1-p_u)\sum_{j>i}f(l,j)\right). \]這是個左右區間的交叉貢獻,用權值線段樹維護 \(f(u)\)
注意,例如左對右有貢獻,直接在右子樹上打乘法或加法之類的標記是不良的——不同時間戳下的標記合併方式不同。因而,我們可以暴力地保證,當且僅當線段樹結點 \(u\) 的後代不會被打上當前時間戳的標記(即,\(u\) 子樹內的貢獻因子都一樣了),我們才給它打乘法標記;否則僅累加係數遞迴傳遞。
所有貢獻完成後,我們再把 \(f'(l)\) 和 \(f'(r)\) 合併(對應點貢獻相加),就得到了 \(f(u)\)。本質上就是做了兩次線段樹合併,所以複雜度是 \(\mathcal O(n\log n)\)
那麼這件事情告訴我們線段樹無所不能,所有看似暴力的 DP 都拿來試一試。(
\(\mathscr{Code}\)
/*+Rainybunny+*/ #include <bits/stdc++.h> #define rep(i, l, r) for (int i = l, rep##i = r; i <= rep##i; ++i) #define per(i, r, l) for (int i = r, per##i = l; i >= per##i; --i) const int MAXN = 3e5, MOD = 998244353, INV1E4 = 796898467; int n, siz[MAXN + 5], ch[MAXN + 5][2], val[MAXN + 5]; int mxv, dc[MAXN + 5], root[MAXN + 5]; inline int mul(const int u, const int v) { return 1ll * u * v % MOD; } inline void subeq(int& u, const int v) { (u -= v) < 0 && (u += MOD); } inline int sub(int u, const int v) { return (u -= v) < 0 ? u + MOD : u; } inline void addeq(int& u, const int v) { (u += v) >= MOD && (u -= MOD); } inline int add(int u, const int v) { return (u += v) < MOD ? u : u - MOD; } struct SegmentTree { static const int MAXND = 3e6; int node, ch[MAXND][2], sum[MAXND], tag[MAXND]; inline void pushup(const int u) { sum[u] = add(sum[ch[u][0]], sum[ch[u][1]]); } inline void pushml(const int u, const int v) { assert(u); sum[u] = mul(sum[u], v), tag[u] = mul(tag[u], v); } inline void pushdn(const int u) { if (tag[u] != 1) { if (ch[u][0]) pushml(ch[u][0], tag[u]); if (ch[u][1]) pushml(ch[u][1], tag[u]); tag[u] = 1; } } inline void insert(int& u, const int l, const int r, const int x) { u = ++node, tag[u] = sum[u] = 1; if (l == r) return ; int mid = l + r >> 1; if (x <= mid) insert(ch[u][0], l, mid, x); else insert(ch[u][1], mid + 1, r, x); } inline void mix(const int u, const int v, const int su, const int sv, const int p) { if (!u && !v) return ; if (u && !v) return pushml(u, su); if (v && !u) return pushml(v, sv); if (!ch[u][0] && !ch[u][1]) return pushml(u, su), pushml(v, sv); pushdn(u), pushdn(v); int ul = sum[ch[u][0]], ur = sum[ch[u][1]]; int vl = sum[ch[v][0]], vr = sum[ch[v][1]], q = sub(1, p); mix(ch[u][0], ch[v][0], add(su, mul(q, vr)), add(sv, mul(q, ur)), p); mix(ch[u][1], ch[v][1], add(su, mul(p, vl)), add(sv, mul(p, ul)), p); pushup(u), pushup(v); } inline void merge(int& u, const int v) { if (!u || !v) return void(u |= v); if (!ch[u][0] && !ch[u][1]) return addeq(sum[u], sum[v]); pushdn(u), pushdn(v), addeq(sum[u], sum[v]); merge(ch[u][0], ch[v][0]), merge(ch[u][1], ch[v][1]); } inline int answer(const int u, const int l, const int r) { if (l == r) return mul(mul(l, dc[l]), mul(sum[u], sum[u])); int mid = l + r >> 1, ret = 0; pushdn(u); addeq(ret, answer(ch[u][0], l, mid)); addeq(ret, answer(ch[u][1], mid + 1, r)); return ret; } } sgt; inline void solve(const int u) { if (!ch[u][0]) { val[u] = std::lower_bound(dc + 1, dc + mxv + 1, val[u]) - dc; sgt.insert(root[u], 1, mxv, val[u]); } else if (!ch[u][1]) { solve(ch[u][0]), root[u] = root[ch[u][0]]; } else { solve(ch[u][0]), solve(ch[u][1]); sgt.mix(root[ch[u][0]], root[ch[u][1]], 0, 0, val[u]); sgt.merge(root[u] = root[ch[u][0]], root[ch[u][1]]); } } int main() { std::ios::sync_with_stdio(false), std::cin.tie(0); std::cin >> n; rep (i, 1, n) { int fa; std::cin >> fa; if (fa) ch[fa][!!ch[fa][0]] = i; } rep (i, 1, n) { std::cin >> val[i]; if (ch[i][0]) val[i] = mul(val[i], INV1E4); else dc[++mxv] = val[i]; } std::sort(dc + 1, dc + mxv + 1); mxv = std::unique(dc + 1, dc + mxv + 1) - dc - 1; solve(1); std::cout << sgt.answer(root[1], 1, mxv) << '\n'; return 0; }