1. 程式人生 > >[CC-BLREDSET]Black and Red vertices of Tree

[CC-BLREDSET]Black and Red vertices of Tree

[CC-BLREDSET]Black and Red vertices of Tree

題目大意:

有一棵\(n(\sum n\le10^6)\)個結點的樹,每個結點有一種顏色(紅色、黑色、白色)。刪去一個由紅色點構成的連通塊,使得存在一個黑點和一個白點,滿足這兩個點不連通。問有多少種刪法。

思路:

設滿足刪掉這個點後,使得存在一個黑點和一個白點,滿足這兩個點不連通的紅點為關鍵點。那麼我們可以用兩個\(\mathcal O(n)\)的樹形DP求出所有的關鍵點。剩下的問題就變成了求有多少種全紅連通塊使得該連通塊中至少有一個關鍵點,這顯然又可以用一個\(\mathcal O(n)\)樹形DP求出。

原始碼:

#include<cstdio>
#include<cctype>
#include<vector>
inline int getint() {
    register char ch;
    while(!isdigit(ch=getchar()));
    register int x=ch^'0';
    while(isdigit(ch=getchar())) x=(((x<<2)+x)<<1)+(ch^'0');
    return x;
}
const int N=1e5+1,mod=1e9+7;
bool mark[N];
int col[N],cnt1[N],cnt2[N],f[N][2];
std::vector<int> e[N];
inline void add_edge(const int &u,const int &v) {
    e[u].push_back(v);
    e[v].push_back(u);
}
void dfs(const int &x,const int &par) {
    cnt1[x]=cnt2[x]=0;
    if(col[x]==1) cnt1[x]=1;
    if(col[x]==2) cnt2[x]=1;
    for(unsigned i=0;i<e[x].size();i++) {
        const int &y=e[x][i];
        if(y==par) continue;
        dfs(y,x);
        cnt1[x]+=cnt1[y];
        cnt2[x]+=cnt2[y];
    }
}
void move(const int &x,const int &par) {
    bool g1=false,g2=false;
    if(x!=1) {
        g1=cnt1[par]-cnt1[x];
        g2=cnt2[par]-cnt2[x];
        cnt1[x]+=cnt1[par]-cnt1[x];
        cnt2[x]+=cnt2[par]-cnt2[x];
    }
    mark[x]=false;
    for(unsigned i=0;i<e[x].size();i++) {
        const int &y=e[x][i];
        if(y==par) continue;
        mark[x]|=cnt1[y]&&g2;
        mark[x]|=cnt2[y]&&g1;
        g1|=cnt1[y];
        g2|=cnt2[y];
        move(y,x);
    }
}
void dp(const int &x) {
    col[x]=-1;
    f[x][mark[x]]=1;
    f[x][!mark[x]]=0;
    for(unsigned i=0;i<e[x].size();i++) {
        const int &y=e[x][i];
        if(col[y]) continue;
        dp(y);
        f[x][1]=(1ll*f[x][1]*(f[y][0]+f[y][1]+1)%mod+1ll*f[x][0]*f[y][1]%mod)%mod;
        f[x][0]=1ll*f[x][0]*(f[y][0]+1)%mod;
    }
}
int main() {
    for(register int T=getint();T;T--) {
        const int n=getint();
        for(register int i=1;i<n;i++) {
            add_edge(getint(),getint());
        }
        for(register int i=1;i<=n;i++) {
            col[i]=getint();
        }
        dfs(1,0);
        move(1,0);
        for(register int i=1;i<=n;i++) {
            if(!col[i]) dp(i);
        }
        for(register int i=1;i<=n;i++) {
            e[i].clear();
        }
        int ans=0;
        for(register int i=1;i<=n;i++) {
            if(col[i]==-1) (ans+=f[i][1])%=mod;
        }
        printf("%d\n",ans);
    }
    return 0;
}