1. 程式人生 > >牛客多校第3場 B Expected Number of Nodes 題解

牛客多校第3場 B Expected Number of Nodes 題解

題意: 就是給你一個壓縮圖的方式,然後讓你計算任意選擇k(1….n)個點之後的圖上剩餘點的期望數(mod 1e9+7).

做法:
很顯然的得到兩個結論(記當前要保留的點為k,總點數為n).
1.對於一個度數小於等於2的點,他的留下來的概率為(n-1,k-1)/(n,k). 因為只有這個點被選到了,才有可能留下來.
2.對於一個度數大於2的點,他留下來的有兩種情況,一種是他被選到了, 還有就是他有3個及以上的兒子所在的子樹中有點被選到.那麼可以得出其概率為
1 - ( \sum_{i<j}(si_i + si_j, k) + (m-2) * \sum_i(si_i,k) ) / (N, k).


其中m為當前點兒子的個數,其兒子所在的子樹的大小分別為[si_1, si_2, si_3…., si_m].

我們所求的期望數即為所有點保留的概率的和.
對於第1類的點我們可以很方便的維護,直接加入到答案中或者統計他的數量最後一起加.
對於第2類的點,難以維護的就是\sum_{i<j}(si_i + si_j, k),其他的數值都是可以直接加到答案上的. 首先我們要考慮到 si_i + si_j 的最大值為n,那麼我們就可以用cnt[]去維護,cnt[x]表示si_i+si_j = x的數量, 那麼我們就可以暴力的維護cnt[] (O(n2))
然後對於每一個k,統一加入到答案中(O(n2)).
這樣我們就可以算出所有的答案了.
具體的過程見程式碼.

#include<bits/stdc++.h>
#define fi first
#define se second
#define lson l,mid,o<<1
#define rson mid+1,r,o<<1|1
#define fio ios::sync_with_stdio(false);cin.tie(0)
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int, int> PII;
typedef pair<int
, int> P; typedef pair<PII, int> PIII; const LL INF = 0x3f3f3f3f; const int N = 5e3 + 10; const int M = 11; const LL mod = 1e9 + 7; const double PI=acos(-1); inline int ab(int x){return x < 0 ? -x : x;} inline LL mm(LL x){return x >= mod ? x - mod : x < 0 ? x + mod : x;} int F[N], Finv[N], inv[N]; void init(){ inv[1] = 1; for(int i = 2; i < N; i ++){ inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod; } F[0] = Finv[0] = 1; for(int i = 1; i < N; i ++){ F[i] = F[i-1] * 1ll * i % mod; Finv[i] = Finv[i-1] * 1ll * inv[i] % mod; } } int comb(int n, int m){ if(m < 0 || m > n) return 0; return F[n] * 1ll * Finv[n - m] % mod * Finv[m] % mod; } int invcomb(int n, int m){ if(m < 0 || m > n) return 0; return Finv[n] * 1ll * F[n - m] % mod * F[m] % mod; } vector<int>son[N]; int si[N]; int n; int cnt[N], tot = 0; int ans[N]; void dfs(int o, int fa){ si[o] = 1; LL sum[N]; for(int i = 0; i <= n; ++i) sum[i] = 0; for(auto it : son[o]){ if(it == fa) continue; dfs(it, o); si[o] += si[it]; } if(son[o].size() <= 2) tot++; else{ for(int i = 0; i < son[o].size(); ++i){ if(son[o][i] == fa) continue; for(int j = 0; j < i; ++j){ if(son[o][j] == fa) continue; cnt[si[son[o][i]] + si[son[o][j]]]++; } for(int k = 1; k <= n; ++k){ ans[k] += (son[o].size() - 2) * 1ll * comb(si[son[o][i]], k) % mod; if(ans[k] >= mod) ans[k] -= mod; } } if(o != 1){ for(int j = 0; j < son[o].size(); ++j){ if(son[o][j] == fa) continue; cnt[n - si[o] + si[son[o][j]]]++; } for(int k = 1; k <= n; ++k){ ans[k] += (son[o].size() - 2) * 1ll * comb(n - si[o], k) % mod; if(ans[k] >= mod) ans[k] -= mod; } } } } int main() { init(); int u, v; scanf("%d", &n); for(int i = 1; i < n; ++i){ scanf("%d%d", &u, &v); son[u].push_back(v); son[v].push_back(u); } dfs(1, 1); for(int k = 1; k <= n; ++k){ ans[k] += tot * 1ll * comb(n - 1, k -1) % mod; if(ans[k] >= mod) ans[k] -= mod; for(int i = k; i <= n; ++i){ ans[k] -= cnt[i] * 1ll * comb(i, k) % mod; if(ans[k] < 0) ans[k] += mod; } ans[k] = ans[k] * 1ll * invcomb(n, k) % mod; ans[k] += n - tot; if(ans[k] >= mod) ans[k] -= mod; printf("%d\n", ans[k]); } return 0; }