題解-PKUWC2018 Minimax
阿新 • • 發佈:2018-12-26
Problem
Solution
pkuwc2018最水的一題,要死要活調了一個多小時(1h59min)
我寫這題不是因為它有多好,而是為了保持pkuwc2018的隊形,與這題類似的有一個好玩得多的題
由於答案的形式和期望相去甚遠,所以可以肯定這題和期望無關,而且這麼複雜的式子我們最好還是把所有可能計算出來啦
可以肯定地說這題是從葉子向根合併,合併時分類討論下取最大\((p)\)還是最小\((1-p)\),然後維護前後綴概率和即可
再瞟一眼資料,發現線段樹合併可以解決,完結
Code
注意由於線段樹合併時若一個節點為空則直接返回,但這棵子樹需要乘上整體概率,所以還要在裸的線段樹合併中加上標記維護
#include <bits/stdc++.h> using namespace std; typedef long long ll; inline void read(int&x){ char c11=getchar();x=0;while(!isdigit(c11))c11=getchar(); while(isdigit(c11))x=x*10+c11-'0',c11=getchar(); } const int N=301050,M=N*10,p=998244353; struct Edge{int v,nxt;}a[N]; int ls[M],rs[M],f[M],lz[M]; int head[N],c[N],b[N],rt[N]; int n,_,tot; template <typename _tp> inline _tp qm(_tp x){return x<0?x+p:x<p?x:x-p;} void update(int l,int r,int&x,int vl){ if(!x)lz[x=++tot]=1; if(l==r){f[x]=1;return ;} int mid(l+r>>1); if(vl<=mid)update(l,mid,ls[x],vl); else update(mid+1,r,rs[x],vl); f[x]=qm(f[ls[x]]+f[rs[x]]); } inline void mul(int x,int vl){ f[x]=1ll*f[x]*vl%p; lz[x]=1ll*lz[x]*vl%p; } inline void down(int x){ if(ls[x])mul(ls[x],lz[x]); if(rs[x])mul(rs[x],lz[x]); lz[x]=1; } int merge(int x,int y,ll p0,ll p1,ll p2){ if(x&&lz[x]!=1)down(x); if(y&&lz[y]!=1)down(y); if(!x){mul(y,p0*p1%p+qm(1-p0)*qm(1-p1)%p);return y;} if(!y){mul(x,p0*p2%p+qm(1-p0)*qm(1-p2)%p);return x;} rs[x]=merge(rs[x],rs[y],p0,qm(p1+f[ls[x]]),qm(p2+f[ls[y]])); ls[x]=merge(ls[x],ls[y],p0,p1,p2); f[x]=qm(f[ls[x]]+f[rs[x]]); return x; } void dfs(int x){ for(int i=head[x];i;i=a[i].nxt){ dfs(a[i].v); if(!rt[x])rt[x]=rt[a[i].v]; else rt[x]=merge(rt[x],rt[a[i].v],c[x],0ll,0ll); } if(!head[x])update(1,n,rt[x],c[x]); } int calc(int l,int r,int x){ if(lz[x]!=1)down(x); if(l==r)return 1ll*l *b[l]%p *f[x]%p *f[x]%p; int mid(l+r>>1); return qm(calc(l,mid,ls[x])+calc(mid+1,r,rs[x])); } int main(){ read(n); for(int i=1,x;i<=n;++i){ read(x),a[++_].v=i; a[_].nxt=head[x],head[x]=_; } int tt=0;const ll inv5=796898467; for(int i=1;i<=n;++i){ read(c[i]); if(head[i])c[i]=inv5*c[i]%p; else b[++tt]=c[i]; } sort(b+1,b+tt+1); int end=unique(b+1,b+tt+1)-b; for(int i=1;i<=n;++i) if(!head[i])c[i]=lower_bound(b+1,b+end,c[i])-b; n=end-1;dfs(1); printf("%d\n",calc(1,n,1)); return 0; }