牛客多校第3場 B Expected Number of Nodes 題解
阿新 • • 發佈:2019-01-31
題意: 就是給你一個壓縮圖的方式,然後讓你計算任意選擇k(1….n)個點之後的圖上剩餘點的期望數(mod 1e9+7).
做法:
很顯然的得到兩個結論(記當前要保留的點為k,總點數為n).
1.對於一個度數小於等於2的點,他的留下來的概率為(n-1,k-1)/(n,k). 因為只有這個點被選到了,才有可能留下來.
2.對於一個度數大於2的點,他留下來的有兩種情況,一種是他被選到了, 還有就是他有3個及以上的兒子所在的子樹中有點被選到.那麼可以得出其概率為
1 - ( \sum_{i<j}(si_i + si_j, k) + (m-2) * \sum_i(si_i,k) ) / (N, k).
其中m為當前點兒子的個數,其兒子所在的子樹的大小分別為[si_1, si_2, si_3…., si_m].
我們所求的期望數即為所有點保留的概率的和.
對於第1類的點我們可以很方便的維護,直接加入到答案中或者統計他的數量最後一起加.
對於第2類的點,難以維護的就是\sum_{i<j}(si_i + si_j, k)
,其他的數值都是可以直接加到答案上的. 首先我們要考慮到 si_i + si_j 的最大值為n,那麼我們就可以用cnt[]去維護,cnt[x]表示si_i+si_j = x的數量, 那麼我們就可以暴力的維護cnt[] (O(n2))
然後對於每一個k,統一加入到答案中(O(n2)).
這樣我們就可以算出所有的答案了.
具體的過程見程式碼.
#include<bits/stdc++.h>
#define fi first
#define se second
#define lson l,mid,o<<1
#define rson mid+1,r,o<<1|1
#define fio ios::sync_with_stdio(false);cin.tie(0)
using namespace std;
typedef long long LL;
typedef unsigned long long uLL;
typedef pair<int, int> PII;
typedef pair<int , int> P;
typedef pair<PII, int> PIII;
const LL INF = 0x3f3f3f3f;
const int N = 5e3 + 10;
const int M = 11;
const LL mod = 1e9 + 7;
const double PI=acos(-1);
inline int ab(int x){return x < 0 ? -x : x;}
inline LL mm(LL x){return x >= mod ? x - mod : x < 0 ? x + mod : x;}
int F[N], Finv[N], inv[N];
void init(){
inv[1] = 1;
for(int i = 2; i < N; i ++){
inv[i] = (mod - mod / i) * 1ll * inv[mod % i] % mod;
}
F[0] = Finv[0] = 1;
for(int i = 1; i < N; i ++){
F[i] = F[i-1] * 1ll * i % mod;
Finv[i] = Finv[i-1] * 1ll * inv[i] % mod;
}
}
int comb(int n, int m){
if(m < 0 || m > n) return 0;
return F[n] * 1ll * Finv[n - m] % mod * Finv[m] % mod;
}
int invcomb(int n, int m){
if(m < 0 || m > n) return 0;
return Finv[n] * 1ll * F[n - m] % mod * F[m] % mod;
}
vector<int>son[N];
int si[N];
int n;
int cnt[N], tot = 0;
int ans[N];
void dfs(int o, int fa){
si[o] = 1;
LL sum[N];
for(int i = 0; i <= n; ++i) sum[i] = 0;
for(auto it : son[o]){
if(it == fa) continue;
dfs(it, o);
si[o] += si[it];
}
if(son[o].size() <= 2) tot++;
else{
for(int i = 0; i < son[o].size(); ++i){
if(son[o][i] == fa) continue;
for(int j = 0; j < i; ++j){
if(son[o][j] == fa) continue;
cnt[si[son[o][i]] + si[son[o][j]]]++;
}
for(int k = 1; k <= n; ++k){
ans[k] += (son[o].size() - 2) * 1ll * comb(si[son[o][i]], k) % mod;
if(ans[k] >= mod) ans[k] -= mod;
}
}
if(o != 1){
for(int j = 0; j < son[o].size(); ++j){
if(son[o][j] == fa) continue;
cnt[n - si[o] + si[son[o][j]]]++;
}
for(int k = 1; k <= n; ++k){
ans[k] += (son[o].size() - 2) * 1ll * comb(n - si[o], k) % mod;
if(ans[k] >= mod) ans[k] -= mod;
}
}
}
}
int main()
{
init();
int u, v;
scanf("%d", &n);
for(int i = 1; i < n; ++i){
scanf("%d%d", &u, &v);
son[u].push_back(v);
son[v].push_back(u);
}
dfs(1, 1);
for(int k = 1; k <= n; ++k){
ans[k] += tot * 1ll * comb(n - 1, k -1) % mod;
if(ans[k] >= mod) ans[k] -= mod;
for(int i = k; i <= n; ++i){
ans[k] -= cnt[i] * 1ll * comb(i, k) % mod;
if(ans[k] < 0) ans[k] += mod;
}
ans[k] = ans[k] * 1ll * invcomb(n, k) % mod;
ans[k] += n - tot;
if(ans[k] >= mod) ans[k] -= mod;
printf("%d\n", ans[k]);
}
return 0;
}