1. 程式人生 > 實用技巧 >loj2537. 「PKUWC2018」Minimax

loj2537. 「PKUWC2018」Minimax

題面

題意:自己看去

題解:先考慮一個暴力的樹形dp。設\(f_{i,j}\)表示節點\(i\)權值為\(j\)的概率。那麼對於所有有兩個兒子的節點\(i\),設它的兩個兒子是\(x,y\),那麼對於所有在\(x\)中出現的權值\(j\),有\(f_{i,j}=f_{x,j}\times ((p_i \times \sum_{k=1}^{j-1}f_{y,k}) +((1-p_i)\times \sum_{k=j+1}^\infty f_{y,k}))\)。對於\(y\)中出現的權值也一樣。由於轉移方程中出現了區間和,考慮用線段樹來維護這個區間和。

考慮線上段樹合併的同時進行dp。考慮當前合併的兩個節點是\(x,y\)

,區間是\([l,r]\),二者的兒子分別為\(lc_x,rc_x,lc_y,rc_y\)。那麼對於當前節點,\(lc_x,rc_y\)\(rc_x,lc_y\)都會互相造成貢獻。線上段樹合併時記錄一下這些貢獻的和,然後遞迴到\(x,y\)的兒子進行合併。當\(x=0\)\(y=0\)時,對於一段區間的\(f\)造成的貢獻就累加完了,打一個乘法標記即可。

時間複雜度:\(O(nlogn)\)。如果不離散化的話,時間複雜度就是\(O(nlogV)\)

程式碼:

#include<bits/stdc++.h>
using namespace std;
#define re register int
#define F(x,y,z) for(re x=y;x<=z;x++)
#define FOR(x,y,z) for(re x=y;x>=z;x--)
typedef long long ll;
#define I inline void
#define IN inline int
#define C(x,y) memset(x,y,sizeof(x))
#define STS system("pause")
template<class D>I read(D &res){
	res=0;register D g=1;register char ch=getchar();
	while(!isdigit(ch)){
		if(ch=='-')g=-1;
		ch=getchar();
	}
	while(isdigit(ch)){
		res=(res<<3)+(res<<1)+(ch^48);
		ch=getchar();
	}
	res*=g;
}
const int Mod=998244353,inv=796898467;
typedef pair<int,int>pii;
pii p[303000];
vector<int>e[303000];
int n,m,now,ans,fa[303000],a[303000],xu[303000],b[303000];
int tot,d[303000],root[303000],lc[10100000],rc[10100000],w[10100000],tag[10100000];
I add(int &x,int y){(x+=y)>=Mod?x-=Mod:0;}
IN Plus(int x,int y){(x+=y)>=Mod?x-=Mod:0;return x;}
IN Pow(int x,int y=Mod-2){
	re res=1;
	while(y){
		if(y&1)res=(ll)res*x%Mod;
		x=(ll)x*x%Mod;
		y>>=1;
	}
	return res;
}
I D_2(int x){
	if(e[x].empty())return p[++m]=make_pair(a[x],x),void();
	for(auto d:e[x])D_2(d);
}
#define lt lc[k],l,mid
#define rt rc[k],mid+1,r
I mul(int x,int v){
	if(!x)return;
	w[x]=(ll)w[x]*v%Mod;tag[x]=(ll)tag[x]*v%Mod;
}
I push_down(int x){
	mul(lc[x],tag[x]);mul(rc[x],tag[x]);tag[x]=1;
}
I modi(int &k,int l,int r,int x){
	k=++tot;w[k]=tag[k]=1;
	if(l==r)return;
	re mid=(l+r)>>1;
	if(x<=mid)modi(lt,x);
	else modi(rt,x);
}
IN merge(int x,int y,int vx,int vy){
	if(!x||!y){
//		cout<<"!"<<vx<<" "<<vy<<endl;
		if(y)mul(y,vy);
		if(x)mul(x,vx);	
		return x+y;
	}
	push_down(x);push_down(y);
	re val[2][2];val[0][0]=w[lc[x]];val[0][1]=w[rc[x]];val[1][0]=w[lc[y]];val[1][1]=w[rc[y]];
	lc[x]=merge(lc[x],lc[y],(vx+(ll)val[1][1]*(Mod+1-now)%Mod)%Mod,(vy+(ll)val[0][1]*(Mod+1-now)%Mod)%Mod);
	rc[x]=merge(rc[x],rc[y],(vx+(ll)val[1][0]*now%Mod)%Mod,(vy+(ll)val[0][0]*now%Mod)%Mod);
	w[x]=Plus(w[lc[x]],w[rc[x]]);
	return x;
}
I damage(int k,int l,int r){
	if(l==r)return d[l]=w[k],void();
	push_down(k);
	re mid=(l+r)>>1;
	damage(lt);damage(rt);
}
I D_1(int x){
	if(e[x].empty()){
		modi(root[x],1,m,b[x]);
//		cout<<"A"<<b[x]<<endl;
		return;
	}
	root[x]=0;
	for(auto d:e[x]){
		D_1(d);
		if(!root[x])root[x]=root[d];
		else now=(ll)a[x]*inv%Mod,root[x]=merge(root[x],root[d],0,0);
	}
//	assert(!w[0]);
}
int main(){
	read(n);
	F(i,1,n)read(fa[i]),e[fa[i]].emplace_back(i);
	F(i,1,n)read(a[i]);
	D_2(1);sort(p+1,p+1+m);F(i,1,m)b[p[i].second]=i;
	D_1(1);damage(root[1],1,m);
	F(i,1,m)add(ans,(ll)i*p[i].first%Mod*d[i]%Mod*d[i]%Mod);
	printf("%d",ans);
	return 0;
}
/*
3
0 1 1
5000 1 2
*/