牛客國慶集訓派對Day3 B-Tree (樹上包含某個節點的連通子集個數)
阿新 • • 發佈:2018-12-13
題目描述
修修去年種下了一棵樹,現在它已經有n個結點了。
修修非常擅長數數,他很快就數出了包含每個點的連通點集的數量。
瀾瀾也想知道答案,但他不會數數,於是他把問題交給了你。
輸入描述:
第一行一個整數n (1≤ n ≤ 106),接下來n-1行每行兩個整數ai,bi表示一條邊 (1≤ ai,bi≤ n)。
輸出描述:
輸出n行,每行一個非負整數。第i行表示包含第i個點的連通點集的數量對109+7取模的結果。
示例1
輸入
6 1 2 1 3 2 4 4 5 4 6
輸出
12 15 7 16 9 9
解題思路:對於一棵樹,不妨設以1為根。那麼我們可以通過一個樹上dp,很容易的就能求出包含根的樹上連通點集個數。方程如下。這樣子就計算出了以u為根的子樹中包含u的連通點集個數。
dp[u]=dp[u]*(dp[v]+1);
但是這樣子只有根的答案是正確的,其他節點的答案都不正確。
假設我們已經知道了一個節點往上走的子樹的答案ans,那麼我們就可以很容易的計算出當前子樹的正確答案
dp[u]=(ans+1)*dp[u];
那麼我們怎麼算出網上走的子樹的答案呢?我們在深搜的時候,暴力計算其他孩子對答案的貢獻即可,但是這樣複雜度最壞是O(N*sqrt(N))的,所以要優化,實際上對於往上走的答案,用它父親的答案除以dp[u]+1即可算出,但是這裡要用逆元處理,會出現inv(MOD)%MOD==0的情況,這裡用暴力算就好了。
#include<iostream> #include<algorithm> #include<math.h> #include<queue> #include<string> #include<vector> #include<bitset> using namespace std; typedef long long ll; const int MAXN=1000006; const ll MOD=1e9+7; ll pow_mod(ll a, ll k) { ll rst = 1; while (k) { if (k&1) rst = rst * a % MOD; a = a * a % MOD; k >>= 1; } return rst; } inline ll inv(ll x) { return pow_mod(x, MOD - 2)%MOD; } inline void scan_d(int &ret) { char c; ret = 0; while ((c = getchar()) < '0' || c > '9'); while (c >= '0' && c <= '9') { ret = ret * 10 + (c - '0'), c = getchar(); } } void Out(ll a) { // 輸出外掛 if (a < 0) { putchar('-'); a = -a; } if (a >= 10) { Out(a / 10); } putchar(a % 10 + '0'); } vector<int> G[MAXN]; ll dp[MAXN];//儲存所有子樹的的包含子樹的根的連通點集個數 int pre[MAXN]; ll sta[MAXN];//儲存不包含當前子樹的包含子樹的根的父親的連通點集個數。 ll ans[MAXN]; int N; void dfs1(int u,int fa){ pre[u]=fa; for(int i=0;i<G[u].size();i++){ if(G[u][i]!=fa){ dfs1(G[u][i],u); dp[u]=dp[u]*(dp[G[u][i]]+1)%MOD; } } } void dfs2(int u,int fa){ if(fa!=-1){ if((dp[u]+1)%MOD==0)//特殊情況,暴力處理 { ll num=sta[fa]+1; for(int i=0;i<G[fa].size();i++){ int v=G[fa][i]; if(v==pre[fa]||v==u) continue; num=num*(dp[v]+1)%MOD; } sta[u]=num; ans[u]=(num+1)*dp[u]%MOD; } else//否則用逆元直接計算。 { ll num=ans[fa]*inv(dp[u]+1)%MOD; sta[u]=num; ans[u]=(num+1)*dp[u]%MOD; } } for(int i=0;i<G[u].size();i++){ if(G[u][i]!=fa){ dfs2(G[u][i],u); } } } int main(){ scan_d(N); int u,v; for(int i=1;i<N;i++){ dp[i]=1; scan_d(u); scan_d(v); G[u].emplace_back(v); G[v].emplace_back(u); } dp[N]=1; dfs1(1,-1); ans[1]=dp[1]; dfs2(1,-1); for(int i=1;i<=N;i++){ Out(ans[i]); puts(""); } return 0; }