1. 程式人生 > 其它 >[CEOI2020]星際迷航 題解

[CEOI2020]星際迷航 題解

博弈+換根 dp+矩陣快速冪優化 dp

Statement

P6803 [CEOI2020]星際迷航 - 洛谷 | 電腦科學教育新生態 (luogu.com.cn)

Solution

看到 \(D\le 10^{18}\) ,知道這道題最後大概是需要一個矩陣快速冪

所以先考慮 \(D=1\) 的情況,容易發現對於一個必勝態而言,無論接一個都不會改變狀態,對於一個必敗態,只有接上一個比敗態會改變狀態

容易一遍 dfs 求出每一個節點的狀態 \(g[u]=0/1\) ,必敗/勝。令 \(lose\) 表示必敗態個數

\(f[u]\) 表示 \(u\) 的子樹中,有多少個點在接入一個必敗態後會使得 \(u\)

的狀態改變

所以,若 \(g[1]=0\)\(ans=f[1]\times lose\) ,也就是對於那些可以把 \(1\) 扳正的位置,給他配一個 \(lose\)

\(g[1]=1\) ,$ans=n\times (n-lose)+lose\times(n-f[1]) $ ,現在不需要對 \(1\) 更改,所以對於 \((n-lose)\) 個必勝點,隨便接啥都可以,然後不能給那 \(f[1]\) 個點接敗點

現在考慮擴充套件一下,\(D=2\) 咋做

發現只需要把怎麼把第二第三棵樹怎麼合併一下然後當成 \(D=2\)

發現我們算答案其實需求的是 \(lose\)\((n-lose)\)

,其他的量不會因為 \(D\) 的改變而改變

也就是說,現在需要求出的是 \(lose\)\(win\) 的方案數,設為 \(l\)\(w\) ,也就是第三棵樹往第二棵樹上接,第二棵樹中敗點個數總和 和 勝點個數總和

求出後,答案依然是 \([g[1]==0] f[1]\times l+[g[1]==1]n\times w+[g[1]==1]l\times(n-f[1])\)

為了計算 \(l,w\),發現好像需要每一個點的 \(f\) ,可以藉助換根 \(dp\) 求出

所以,當 \(D=2\) 時,

\[l=\sum[g[i]=0](n-f[i])\times lose+\sum[g[i]=0]n\times(n-lose)+\sum[g[i]=1] f[i]\times lose\\ w=\sum[g[i]=1](n-f[i])\times lose+\sum[g[i]=1]n\times(n-lose)+\sum[g[i]=0]f[i]\times lose \]

容易注意到其實 \(lose/win\)

\(l/w\) 的本質是一樣的,我們把上面的轉移寫成矩陣的形式

\[\left[\begin{array}{cc|r}l^{\prime}&w^{\prime}\end{array}\right] \left[\begin{array}{cc|r}\sum[g=0](n-f)+[g=1]f&\sum[g=1](n-f)+[g=0]f\\\sum [g=0]n&\sum [g=1]n\end{array}\right]=\left[\begin{array}{cc|r}l&w\end{array}\right] \]

最開始是 \([lose,n-lose]\)

很對,複雜度 \(O(n\log K+n)\)

換根 \(dp\) 有一點小細節

Code

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
const int N = 1e5+5;
const int mod = 1e9+7;

char buf[1<<23],*p1=buf,*p2=buf;
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++)
ll read(){
    ll s=0,w=1; char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')w=-1; ch=getchar();}
    while(isdigit(ch))s=s*10+(ch^48),ch=getchar();
    return s*w;
}

struct matrix{
    int a[2][2];
    matrix(bool fg=false){
        memset(a,0,sizeof(a));
        if(fg)a[0][0]=a[1][1]=1;
    }
    matrix operator*(const matrix&rhs)const{
        matrix res;
        for(int i=0;i<2;++i)for(int j=0;j<2;++j)
            res.a[i][j]=(1ll*a[i][0]*rhs.a[0][j]%mod+1ll*a[i][1]*rhs.a[1][j]%mod)%mod;
        return res;
    }
}p,t;
vector<int>Edge[N];
int g[N],s0[N],f[N],sf[N][2];
int g2[N],f2[N];
int n,lose;
ll k;

void dfs1(int u,int fath){
    for(auto v:Edge[u])if(v^fath){
        dfs1(v,u);
        s0[u]+=!g[v];
        sf[u][g[v]]+=f[v];
    }
    g[u]=s0[u]>0;
    if(s0[u]==1)f[u]=sf[u][0];
    else if(s0[u]==0)f[u]=sf[u][1]+1;
}
void dfs2(int u,int fath){
    lose+=!g[u],f2[u]=f[u],g2[u]=g[u];
    for(auto v:Edge[u])if(v^fath){
        int dpu=g[u],s0u=s0[u],ru=f[u],sru0=sf[u][0],sru1=sf[u][1];
        int dpv=g[v],s0v=s0[v],rv=f[v],srv0=sf[v][0],srv1=sf[v][1];

        s0[u]-=!g[v];
        sf[u][g[v]]-=f[v];
        g[u]=s0[u]>0;
        if(s0[u]==1)f[u]=sf[u][0];
        else if(s0[u]==0)f[u]=sf[u][1]+1;
        else f[u]=0;

        s0[v]+=!g[u];
        sf[v][g[u]]+=f[u];
        g[v]=s0[v]>0;
        if(s0[v]==1)f[v]=sf[v][0];
        else if(s0[v]==0)f[v]=sf[v][1]+1;
        else f[v]=0;
        
        dfs2(v,u);
        g[u]=dpu,s0[u]=s0u,f[u]=ru,sf[u][0]=sru0,sf[u][1]=sru1;
        g[v]=dpv,s0[v]=s0v,f[v]=rv,sf[v][0]=srv0,sf[v][1]=srv1;
    }
}
matrix ksm(matrix a,ll b){
    matrix res(1);
    while(b){
        if(b&1)res=res*a;
        a=a*a,b>>=1;
    }
    return res;
}

signed main(){
    n=read(),k=read();
    for(int i=1,u,v;i<n;++i)
        u=read(),v=read(),
        Edge[u].push_back(v),
        Edge[v].push_back(u);
    dfs1(1,0),dfs2(1,0);
    p.a[0][0]=lose,p.a[0][1]=n-lose;
    for(int i=1;i<=n;++i)
        if(g2[i]==0)
            (t.a[0][0]+=n-f2[i])%=mod,
            (t.a[0][1]+=f2[i])%=mod,
            (t.a[1][0]+=n)%=mod;
        else (t.a[0][0]+=f2[i])%=mod,
            (t.a[0][1]+=n-f2[i])%=mod,
            (t.a[1][1]+=n)%=mod;
    p=p*ksm(t,k-1);
    if(g[1])printf("%lld\n",(1ll*(n-f2[1])*p.a[0][0]%mod+1ll*n*p.a[0][1]%mod)%mod);
    else printf("%lld\n",1ll*f2[1]*p.a[0][0]%mod);
    return 0;
}