聯考20200729 T3 樹上的鼠
阿新 • • 發佈:2020-07-30
分析:
首先我們從博弈入手,看看什麼情況先手必勝
先考慮一條鏈的情況,假設鏈長為偶數,即存在中點,先手就可以搶佔中點,後手無論怎麼走,先手都可以走到其關於中點的對稱點上
最後後手無法操作,先手必勝
當鏈長為奇數時,先手兩個中點隨便搶一個就可以勝利
換在樹上,求出直徑中點,非直徑上的點可以轉化為直徑上的,效果等價
那麼先手必敗當且僅當起點為連通塊直徑的唯一中點
轉化一下問題,即一個連通塊使得1不是直徑唯一中點,即1的兒子的最深深度僅有一個
考慮DP,\(f_{u,i}\)表示\(u\)為根,深度至多為\(i\)的連通塊方案數
這個暴力\(O(n^2)\)DP,發現只與最深深度有關,直接長鏈剖分優化就好了
統計答案時強行欽定某個兒子最深深度為\(D\)
複雜度\(O(n)\)
#include<cstdio> #include<cmath> #include<cstring> #include<iostream> #include<algorithm> #include<queue> #include<set> #include<map> #include<vector> #include<string> #define maxn 2000005 #define MOD 998244353 using namespace std; inline long long getint() { long long num=0,flag=1;char c; while((c=getchar())<'0'||c>'9')if(c=='-')flag=-1; while(c>='0'&&c<='9')num=num*10+c-48,c=getchar(); return num*flag; } int n; vector<int>G[maxn]; int cur,pos[maxn],son[maxn],dpt[maxn],mxd[maxn]; long long f[maxn],g[maxn],h[maxn],sum[maxn],pd[maxn]; int tmp[maxn]; inline void dfs1(int u,int fa) { dpt[u]=dpt[fa]+1,mxd[u]=dpt[u],g[u]=1; for(int i=0;i<G[u].size();i++)if(G[u][i]!=fa) { dfs1(G[u][i],u),mxd[u]=max(mxd[u],mxd[G[u][i]]); if(mxd[G[u][i]]>mxd[son[u]])son[u]=G[u][i]; g[u]=g[u]*(g[G[u][i]]+1)%MOD; } } inline void dfs2(int u,int fa) { pos[u]=++cur;if(son[u])dfs2(son[u],u); for(int i=0;i<G[u].size();i++)if(G[u][i]!=fa&&G[u][i]!=son[u])dfs2(G[u][i],u); } inline void down(int u,int lim) { if(h[u]!=1) { f[u]=f[u]*h[u]%MOD; if(u<lim)h[u+1]=h[u+1]*h[u]%MOD; h[u]=1; } } inline void solve(int u,int p) { f[pos[u]]=h[pos[u]]=sum[pos[u]]=1; if(son[u])solve(son[u],u); for(int i=0;i<G[u].size();i++)if(G[u][i]!=p&&G[u][i]!=son[u]) { int v=G[u][i];solve(v,u); for(int j=0;j<=mxd[v]-dpt[v];j++) { int k=j+1; down(pos[v]+j,pos[v]+mxd[v]-dpt[v]),sum[pos[v]+j]=((j>0)*sum[pos[v]+j-1]+f[pos[v]+j])%MOD; if(u==1)continue; down(pos[u]+k,pos[u]+mxd[u]-dpt[u]),sum[pos[u]+k]=(sum[pos[u]+k-1]+f[pos[u]+k])%MOD; f[pos[u]+k]=f[pos[u]+k]*((j>0)*sum[pos[v]+j-1]+1)%MOD+f[pos[v]+j]*sum[pos[u]+k-1]%MOD+f[pos[u]+k]*f[pos[v]+j]%MOD; f[pos[u]+k]%=MOD; } if(u==1)continue; if(mxd[v]<mxd[u])(h[pos[u]+mxd[v]-dpt[u]+1]*=(sum[pos[v]+mxd[v]-dpt[v]]+1))%=MOD; } } int main() { n=getint(); for(int i=1;i<n;i++) { int u=getint(),v=getint(); G[u].push_back(v),G[v].push_back(u); } dfs1(1,0),dfs2(1,0); solve(1,0); for(int i=2;i<=mxd[1];i++)down(i,mxd[1]),sum[i]=(i>2)*sum[i-1]+f[i]; long long ans=1; for(int i=0;i<=n;i++)pd[i]=1,h[i]=1; for(int i=0;i<G[1].size();i++) { int u=G[1][i]; for(int j=2;j<=mxd[u];j++) { if(h[j]!=1)pd[j]=pd[j]*h[j]%MOD,h[j+1]=h[j+1]*h[j]%MOD,h[j]=1; tmp[pos[u]+j-2]=pd[j-1]; } for(int j=2;j<=mxd[u];j++)pd[j]=pd[j]*(sum[pos[u]+j-2]+1)%MOD; if (mxd[u]<n)(h[mxd[u]+1]*=(sum[pos[u]+mxd[u]-2]+1))%=MOD; } for(int i=0;i<=n;i++)pd[i]=1,h[i]=1; for(int i=G[1].size()-1;~i;i--) { int u=G[1][i]; for(int j=2;j<=mxd[u];j++) { if(h[j]!=1)pd[j]=pd[j]*h[j]%MOD,h[j+1]=h[j+1]*h[j]%MOD,h[j]=1; ans+=tmp[pos[u]+j-2]*f[pos[u]+j-2]%MOD*(pd[j]-pd[j-1])%MOD; } for(int j=2;j<=mxd[u];j++)pd[j]=pd[j]*(sum[pos[u]+j-2]+1)%MOD; if(mxd[u]<n)(h[mxd[u]+1]*=(sum[pos[u]+mxd[u]-2]+1))%=MOD; } printf("%lld\n",((g[1]-ans)%MOD+MOD)%MOD); }