1. 程式人生 > 實用技巧 >【題解】Loj #3340. 「NOI2020」命運

【題解】Loj #3340. 「NOI2020」命運

顯然更好的閱讀體驗

考慮容斥,列舉哪些路徑一定不合法即可:\(\sum (-1)^{|S|}2^{(n-1)-val_S}\) 。其中 \(val_S\)\(S\) 中的路徑覆蓋的邊數條數。

因為所有的路徑的兩個端點滿足祖先關係,考慮用 \(dp\) 解決這個問題。

\(dp_{u,i}\) 表示對於所有起點在 \(u\) 的子樹中的路徑,選了一些,滿足選出來的這些路徑的終點的最小深度為 \(i\) 時的容斥係數和。

這裡考慮將 \(\sum (-1)^{|S|}2^{(n-1)-val_S}\) 拆為 \(2^{n-1}\sum (-1)^{|S|}\frac{1}{2^{val_S}}\)

,然後所謂的容斥係數指的 \((-1)^{|S|}\frac{1}{2^{val_S}}\) ,這樣就可以很方便的合併了。

直接這樣做時間複雜度 \(O(n^2)\) ,可以拿到 \(36\ pts\)

void solve(int u,int fa) {
    if(~top[u]) pls(dp[u][top[u]],mod-inv[dep[u]-top[u]]);
    pls(dp[u][dep[u]],1);

    for(int v:son[u]) if(v!=fa) {
        solve(v,u);
        lep(i,1,dep[u]) lep(j,1,dep[v]) {
            int res=mul(dp[u][i],dp[v][j]);
            if(j<dep[v]) res=mul(res,fac[dep[u]-max(i,j)]);
            pls(tmp[min(i,j)],res);
        }
        lep(i,0,n) dp[u][i]=tmp[i],tmp[i]=0;
    }
}

答案顯然就是 \(dp_{1,1}\)

(注意到可能有很多條起點為 \(u\) 的路徑,但是隻需要保留最下面的一條(即終點深度最神的),因為乘上容斥係數後上面的全部抵消了。

考慮優化。

考慮另外一種更加整潔的 \(O(n^2)\) 做法,即轉移的時候列舉目標位置,然後中間這個 if(j<dep[v]) 太醜了,去掉後在最後單獨處理 \(dp_{u,dep_u}\) 即可。這個 fac[dep[u]-max(i,j)] 也可以拆開。

void solve(int u,int fa) {
    if(~top[u]) pls(dp[u][top[u]],mod-inv[dep[u]-top[u]]);
    pls(dp[u][dep[u]],1);

    for(int v:son[u]) if(v!=fa) {
        solve(v,u);

        lep(i,1,dep[u]) {
            int res=mul(mul(dp[u][i],dp[v][i]),inv[i]);
            lep(j,i+1,dep[u]) pls(res,mul(mul(dp[u][j],dp[v][i]),inv[j]));
            lep(j,i+1,dep[v]) pls(res,mul(mul(dp[u][i],dp[v][j]),inv[j]));
            dp[u][i]=res;
        }
        lep(i,0,n) dp[u][i]=mul(fac[dep[u]],dp[u][i]);
    }
    if(u!=1) dp[u][dep[u]]=mul(dp[u][dep[u]],2);
}

接下來,不妨將 \(dp_{u,i}\) 重定義為 \(dp_{u,i}\cdot \frac{1}{2^i}\) 。那麼轉移的時候就可以去掉後面的 inv[*]

然後就可以發現,\(i\) 這個位置,轉移就是三種:\(u,v\) 都取第 \(i\) 位置,\(u\) 取第 \(i\) 位且 \(v\) 取一個字尾和,\(v\) 取第 \(i\) 位且 \(u\) 取一個字尾和。

PKUWC 2018 Minimax 啟發了我們線段樹合併也可以優化一類有關字首/字尾和轉移的 \(dp\) 。不妨考慮用線段樹合併來解決此題,然後就做完了。

合併過程可以見 merge 函式。

const int N=5e5+5;
int n,m,inv[N],fac[N],dep[N],top[N];
std::vector <int> son[N];

inline void init() {
    fac[0]=inv[0]=1;
    int inv2=modpow(2,-1);
    lep(i,1,n) inv[i]=mul(inv[i-1],inv2),fac[i]=mul(fac[i-1],2);
}

// {{{ Segment_Tree

#define mid ((l+r)>>1)

const int M=2e7+5;
int cnt,rt[N],lc[M],rc[M],sum[M],tag[M];

inline void update(int x,int v) {sum[x]=mul(sum[x],v),tag[x]=mul(tag[x],v);}
inline void pushdown(int x) {if(tag[x]!=1) update(lc[x],tag[x]),update(rc[x],tag[x]),tag[x]=1;}

void modify(int &x,int l,int r,int pos,int val) {
    if(!x) x=++cnt,tag[x]=1; pls(sum[x],val);
    if(l==r) return void();
    pushdown(x);
    pos<=mid?modify(lc[x],l,mid,pos,val):modify(rc[x],mid+1,r,pos,val);
}
void _ex_modify(int &x,int l,int r,int pos) {
    if(!x) x=++cnt,tag[x]=1;
    if(l==r) return sum[x]=mul(sum[x],2),void();
    pushdown(x);
    pos<=mid?_ex_modify(lc[x],l,mid,pos):_ex_modify(rc[x],mid+1,r,pos);
    sum[x]=add(sum[lc[x]],sum[rc[x]]);
}
int query(int x,int l,int r,int pos) {
    if(l==r) return sum[x];
    pushdown(x);
    return (pos<=mid?query(lc[x],l,mid,pos):query(rc[x],mid+1,r,pos));
}

int merge(int x,int y,int l,int r,int s1,int s2) {
    if(!x||!y) {
        if(!x) return update(y,s1),y;
        if(!y) return update(x,s2),x;
    }

    if(l==r) {
        int now=add(add(mul(sum[x],s2),mul(sum[y],s1)),mul(sum[x],sum[y]));
        if(!x) x=y;
        return sum[x]=now,x;
    }
    pushdown(x),pushdown(y);
    lc[x]=merge(lc[x],lc[y],l,mid,add(s1,sum[rc[x]]),add(s2,sum[rc[y]]));
    rc[x]=merge(rc[x],rc[y],mid+1,r,s1,s2);
    sum[x]=add(sum[lc[x]],sum[rc[x]]);
    return x;
}

// }}

void dfs(int u,int fa) {
    dep[u]=dep[fa]+1;
    for(int v:son[u]) if(v!=fa) dfs(v,u);
}
void solve(int u,int fa) {
    if(~top[u]) modify(rt[u],1,n,top[u],mul(mod-inv[dep[u]-top[u]],inv[top[u]]));
    modify(rt[u],1,n,dep[u],inv[dep[u]]);

    for(int v:son[u]) if(v!=fa) {
        solve(v,u);
        rt[u]=merge(rt[u],rt[v],1,n,0,0);
        update(rt[u],fac[dep[u]]);
    }
    if(u!=1) _ex_modify(rt[u],1,n,dep[u]);
}

int u,v;
int main() {
    freopen("destiny.in","r",stdin);
    freopen("destiny.out","w",stdout);

    IN(n);
    lep(i,2,n) IN(u,v),son[u].pb(v),son[v].pb(u);
    dfs(1,0);

    IN(m);
    lep(i,1,n) top[i]=-1;
    lep(i,1,m) IN(u,v),chkmax(top[v],dep[u]);
    init(),solve(1,0);
    printf("%d\n",mul(fac[n],query(rt[1],1,n,1)));
    return 0;
}