[CEOI2020]星際迷航 題解
博弈+換根 dp+矩陣快速冪優化 dp
Statement
P6803 [CEOI2020]星際迷航 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)
Solution
看到 \(D\le 10^{18}\) ,知道這道題最後大概是需要一個矩陣快速冪
所以先考慮 \(D=1\) 的情況,容易發現對於一個必勝態而言,無論接一個都不會改變狀態,對於一個必敗態,只有接上一個比敗態會改變狀態
容易一遍 dfs 求出每一個節點的狀態 \(g[u]=0/1\) ,必敗/勝。令 \(lose\) 表示必敗態個數
設 \(f[u]\) 表示 \(u\) 的子樹中,有多少個點在接入一個必敗態後會使得 \(u\)
所以,若 \(g[1]=0\) ,\(ans=f[1]\times lose\) ,也就是對於那些可以把 \(1\) 扳正的位置,給他配一個 \(lose\);
若 \(g[1]=1\) ,$ans=n\times (n-lose)+lose\times(n-f[1]) $ ,現在不需要對 \(1\) 更改,所以對於 \((n-lose)\) 個必勝點,隨便接啥都可以,然後不能給那 \(f[1]\) 個點接敗點
現在考慮擴充套件一下,\(D=2\) 咋做
發現只需要把怎麼把第二第三棵樹怎麼合併一下然後當成 \(D=2\) 做
發現我們算答案其實需求的是 \(lose\) 和 \((n-lose)\)
也就是說,現在需要求出的是 \(lose\) 和 \(win\) 的方案數,設為 \(l\) 和 \(w\) ,也就是第三棵樹往第二棵樹上接,第二棵樹中敗點個數總和 和 勝點個數總和
求出後,答案依然是 \([g[1]==0] f[1]\times l+[g[1]==1]n\times w+[g[1]==1]l\times(n-f[1])\)
為了計算 \(l,w\),發現好像需要每一個點的 \(f\) ,可以藉助換根 \(dp\) 求出
所以,當 \(D=2\) 時,
\[l=\sum[g[i]=0](n-f[i])\times lose+\sum[g[i]=0]n\times(n-lose)+\sum[g[i]=1] f[i]\times lose\\ w=\sum[g[i]=1](n-f[i])\times lose+\sum[g[i]=1]n\times(n-lose)+\sum[g[i]=0]f[i]\times lose \]容易注意到其實 \(lose/win\)
最開始是 \([lose,n-lose]\)
很對,複雜度 \(O(n\log K+n)\)
換根 \(dp\) 有一點小細節
Code
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+5;
const int mod = 1e9+7;
char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
ll read(){
ll s=0,w=1; char ch=getchar();
while(!isdigit(ch)){if(ch=='-')w=-1; ch=getchar();}
while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
return s*w;
}
struct matrix{
int a[2][2];
matrix(bool fg=false){
memset(a,0,sizeof(a));
if(fg)a[0][0]=a[1][1]=1;
}
matrix operator*(const matrix&rhs)const{
matrix res;
for(int i=0;i<2;++i)for(int j=0;j<2;++j)
res.a[i][j]=(1ll*a[i][0]*rhs.a[0][j]%mod+1ll*a[i][1]*rhs.a[1][j]%mod)%mod;
return res;
}
}p,t;
vector<int>Edge[N];
int g[N],s0[N],f[N],sf[N][2];
int g2[N],f2[N];
int n,lose;
ll k;
void dfs1(int u,int fath){
for(auto v:Edge[u])if(v^fath){
dfs1(v,u);
s0[u]+=!g[v];
sf[u][g[v]]+=f[v];
}
g[u]=s0[u]>0;
if(s0[u]==1)f[u]=sf[u][0];
else if(s0[u]==0)f[u]=sf[u][1]+1;
}
void dfs2(int u,int fath){
lose+=!g[u],f2[u]=f[u],g2[u]=g[u];
for(auto v:Edge[u])if(v^fath){
int dpu=g[u],s0u=s0[u],ru=f[u],sru0=sf[u][0],sru1=sf[u][1];
int dpv=g[v],s0v=s0[v],rv=f[v],srv0=sf[v][0],srv1=sf[v][1];
s0[u]-=!g[v];
sf[u][g[v]]-=f[v];
g[u]=s0[u]>0;
if(s0[u]==1)f[u]=sf[u][0];
else if(s0[u]==0)f[u]=sf[u][1]+1;
else f[u]=0;
s0[v]+=!g[u];
sf[v][g[u]]+=f[u];
g[v]=s0[v]>0;
if(s0[v]==1)f[v]=sf[v][0];
else if(s0[v]==0)f[v]=sf[v][1]+1;
else f[v]=0;
dfs2(v,u);
g[u]=dpu,s0[u]=s0u,f[u]=ru,sf[u][0]=sru0,sf[u][1]=sru1;
g[v]=dpv,s0[v]=s0v,f[v]=rv,sf[v][0]=srv0,sf[v][1]=srv1;
}
}
matrix ksm(matrix a,ll b){
matrix res(1);
while(b){
if(b&1)res=res*a;
a=a*a,b>>=1;
}
return res;
}
signed main(){
n=read(),k=read();
for(int i=1,u,v;i<n;++i)
u=read(),v=read(),
Edge[u].push_back(v),
Edge[v].push_back(u);
dfs1(1,0),dfs2(1,0);
p.a[0][0]=lose,p.a[0][1]=n-lose;
for(int i=1;i<=n;++i)
if(g2[i]==0)
(t.a[0][0]+=n-f2[i])%=mod,
(t.a[0][1]+=f2[i])%=mod,
(t.a[1][0]+=n)%=mod;
else (t.a[0][0]+=f2[i])%=mod,
(t.a[0][1]+=n-f2[i])%=mod,
(t.a[1][1]+=n)%=mod;
p=p*ksm(t,k-1);
if(g[1])printf("%lld\n",(1ll*(n-f2[1])*p.a[0][0]%mod+1ll*n*p.a[0][1]%mod)%mod);
else printf("%lld\n",1ll*f2[1]*p.a[0][0]%mod);
return 0;
}