「UOJ351」新年的葉子
阿新 • • 發佈:2018-11-05
「UOJ351」新年的葉子
題目描述
有一棵大小為 \(n\) 的樹,每次隨機將一個葉子染黑,可以重複染,問期望染多少次後樹的直徑會縮小。\(1 \leq n \leq 5 \times 10^5\)
解題思路 :
首先要利用一個經典的結論,樹的所有直徑的中心為同一個點/邊。不妨給每條邊加一個虛擬點,這樣整顆樹的直徑就只會交於同一個點了。
接下來考慮樹的直徑是由中心的兩個兒子的兩個深度為 \(maxdep\) 的葉子構成的,所以問題等價於將葉子根據中心的兒子分成若干個集合,對於所有染色方案求染到只剩一個集合沒有被完全染黑的期望步數之和,這個東西再除以一個方案數就是答案。
這個東西好難求啊,推了半天式子還是不太會,最後只能看題解輔助推導 \(\text{qwq}\)
/*program by mangoyang*/ #pragma GCC optimize("Ofast","inline","-ffast-math") #pragma GCC target("avx,sse2,sse3,sse4,mmx") #include<bits/stdc++.h> #define inf (0x7f7f7f7f) #define Max(a, b) ((a) > (b) ? (a) : (b)) #define Min(a, b) ((a) < (b) ? (a) : (b)) typedef long long ll; using namespace std; template <class T> inline void read(T &x){ int ch = 0, f = 0; x = 0; for(; !isdigit(ch); ch = getchar()) if(ch == '-') f = 1; for(; isdigit(ch); ch = getchar()) x = x * 10 + ch - 48; if(f) x = -x; } #define pii pair<int, int> #define fi first #define se second const int N = 1000005, mod = 998244353; vector<int> g[N]; int dep[N], f[N], js[N], inv[N], buf[N], n, m, all, mxdep, rt; ll h[N], ans; namespace prework{ queue<pii> q; int vis[N], pre[N]; inline pii bfs(int s){ memset(vis, 0, sizeof(vis)); q.push(make_pair(s, 0)), vis[s] = 1; pii res; while(!q.empty()){ pii now = q.front(); q.pop(); int x = now.fi, dis = now.se; res = now; for(int i = 0; i < g[x].size(); i++){ int v = g[x][i]; if(vis[v]) continue; vis[v] = 1, pre[v] = x; q.push(make_pair(v, dis + 1)); } } return res; } inline void dfs(int u, int fa){ dep[u] = dep[fa] + 1, f[u] = dep[u], buf[u] = 1; for(int i = 0; i < g[u].size(); i++){ int v = g[u][i]; if(v == fa) continue; dfs(v, u); if(f[v] > f[u]) buf[u] = buf[v], f[u] = f[v]; else if(f[v] == f[u]) buf[u] += buf[v]; } if(g[u].size() == 1 && fa) m++; } inline void realmain(){ pii s1 = bfs(1), s2 = bfs(s1.fi); int dis = s2.se / 2; rt = s2.fi; for(int i = 1; i <= dis; i++) rt = pre[rt]; dfs(rt, 0); } } inline void up(ll &x, int y){ (x += y) %= mod; } inline int Pow(int a, int b){ int ans = 1; for(; b; b >>= 1, a = 1ll * a * a % mod) if(b & 1) ans = 1ll * ans * a % mod; return ans; } inline int C(int x, int y){ return 1ll * js[x] * inv[y] % mod * inv[x-y] % mod; } inline int calc(int x, int y){ ll res = 1ll * C(y, x) * C(all-(y-x)-1, x) % mod; (res *= (1ll * js[y-x] * js[all-y] % mod)) %= mod; (res *= (1ll * js[x] * (h[all] - h[y-x]) % mod)) %= mod; return res; } int main(){ js[0] = inv[0] = 1; for(int i = 1; i < N; i++) js[i] = 1ll * js[i-1] * i % mod, inv[i] = Pow(js[i], mod - 2); read(n); int size = n; for(int i = 1, x, y; i < n; i++){ read(x), read(y), ++size; g[x].push_back(size), g[size].push_back(x); g[y].push_back(size), g[size].push_back(y); } prework::realmain(); for(int i = 1; i <= m; i++) up(h[i], 1ll * (h[i-1] + 1ll * m * Pow(i, mod - 2)) % mod); for(int i = 0; i < g[rt].size(); i++) mxdep = max(mxdep, f[g[rt][i]]); for(int i = 0; i < g[rt].size(); i++){ if(f[g[rt][i]] == mxdep) all += buf[g[rt][i]]; else buf[g[rt][i]] = 0; } for(int i = 0; i < g[rt].size(); i++) for(int j = 0; j < buf[g[rt][i]]; j++) up(ans, calc(j, buf[g[rt][i]])); cout << (1ll * ans * inv[all] % mod + mod) % mod; return 0; }