[清華集訓2017]生成樹計數
[清華集訓2017]生成樹計數
題面
題解
考慮貢獻 \(\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right)\),我們先不管後面 \(\sum_{i=1}^nd_i^m\) 的部分。
然後我們就搬來 prufer 序列聯通塊生成樹理論的那一套,設生成樹中每個聯通塊的度數為\(d_i\),那麼貢獻可以表示為
\[\sum_{\sum d_i=2n-2,d_i\geq 1}\frac {(n-2)!}{\prod(d_i-1)!}\prod a_i^{d_i}\prod d_i^m \]其中\(\frac {(n-2)!}{\prod(d_i-1)!}\)是聯通塊構成的 prufer 序列數,\(\prod a_i^{d_i}\)是每個聯通塊選擇點的方案數,\(\prod d_i^m\)是我們現在只考慮的貢獻。
轉化一下就是:
構造 EGF :
\[F(x)=\sum_{i=0}^{\infty} \frac {x^i}{i!}(i+1)^m \]那麼最後我們的答案就是
\[\prod a_i (n-2)![x^{n-2}]\prod F(a_ix) \]現在問題就變為了如何求\(\prod F(a_ix)\)。
考慮這樣一個問題:給定一個\(m\)次多項式\(B(x)=\sum_{i=0}^m b_ix^i\)和\(n\)個數\(a_i\),如何求\(\sum B(a_ix)\)。
把和寫開就是\(\sum B(a_ix)=\sum_i\sum_jb_ja_i^jx^j=\sum_jx^jb_j\sum_ia_i^j\),然後就是對於每個\(j\in[0,m]\),\(a\)的等冪和。
等冪和可以表示為\(\sum_i\frac 1{1-a_ix}\)的每一項,通分後就是\(\frac {\sum_i \prod_{j\neq i} (1-a_jx)}{\prod (1-a_ix)}\)
記\(C(x)=\prod (1-a_ix)\),那麼\(C\)的係數翻轉之後的多項式\(C_R(x)=\prod (-a_i+x)\),求導後\(C_R'(x)=\sum_i1\times\prod _{j\neq i}(-a_j+x)\),最後\(\big (C'_R(x)\big )_R\)就是分子,用分治 FFT 和多項式求逆可以做到\(O(n\log ^2n+m\log m)\)
回到求\(\prod F(a_ix)\),\(\prod F(a_ix)=\exp(\sum\ln F(a_ix))=\exp (\sum(\ln F)(a_ix))\),然後就是上面求的等冪和了。
最後再考慮加上 \(\sum_id_i^m\) 的部分。
發現 \(\mathrm{val}(T) = \left(\prod_{i=1}^{n} {d_i}^m\right)\left(\sum_{i=1}^{n} {d_i}^m\right)\) 就是欽定某個\(i\)的貢獻為\(d_i^{2m}\),令\(G(x)=\sum_{i=0}^{\infty} \frac {x^i}{i!}(i+1)^{2m}\),那麼答案的生成函式可表示為
\[\left (\prod F(a_ix)\right )\left (\sum \frac {G(a_ix)}{F(a_ix)}\right) \]求出\(H(x)=\frac {G(x)}{F(x)}\)後再做一遍等冪和即可。
最後複雜度是\(O(n\log ^2n+n\log m)\),複雜度與所給\(m\)基本無關,但是常數的話你懂的。。
程式碼
#include <bits/stdc++.h>
using namespace std;
int gi() {
int res = 0, w = 1;
char ch = getchar();
while (ch != '-' && !isdigit(ch)) ch = getchar();
if (ch == '-') w = -1, ch = getchar();
while (isdigit(ch)) res = res * 10 + ch - '0', ch = getchar();
return res * w;
}
const int Mod = 998244353;
int fpow(int x, int y) {
int res = 1;
while (y) {
if (y & 1) res = 1ll * res * x % Mod;
x = 1ll * x * x % Mod, y >>= 1;
}
return res;
}
const int MAX_N = 2e5 + 5;
int fac[MAX_N], ifc[MAX_N];
int Limit, rev[MAX_N], omg[MAX_N], inv[MAX_N];
#define VI vector<int>
void FFT_prepare(int len) {
int p = 0;
for (Limit = 1; Limit <= len; Limit <<= 1) p++;
for (int i = 1; i < Limit; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (p - 1));
omg[0] = 1, omg[1] = fpow(3, (Mod - 1) / Limit);
for (int i = 2; i < Limit; i++) omg[i] = 1ll * omg[i - 1] * omg[1] % Mod;
}
void NTT(VI &p, int op) {
p.resize(Limit);
for (int i = 1; i < Limit; i++) if (i < rev[i]) swap(p[i], p[rev[i]]);
for (int i = 1, t = Limit >> 1; i < Limit; i <<= 1, t >>= 1)
for (int j = 0; j < Limit; j += i << 1)
for (int k = 0, o = 0; k < i; k++, o += t) {
int x = p[j + k], y = 1ll * omg[o] * p[i + j + k] % Mod;
p[j + k] = (x + y) % Mod, p[i + j + k] = (x - y + Mod) % Mod;
}
if (!op) {
reverse(p.begin() + 1, p.end());
for (int i = 0; i < Limit; i++) p[i] = 1ll * p[i] * inv[Limit] % Mod;
}
}
void Poly_Der(int n, VI &a, VI &b) {
b.resize(n - 1);
for (int i = 0; i < n - 1; i++) b[i] = 1ll * a[i + 1] * (i + 1) % Mod;
}
void Poly_Int(int n, VI &a, VI &b) {
b.resize(n + 1), b[0] = 0;
for (int i = 1; i <= n; i++) b[i] = 1ll * a[i - 1] * inv[i] % Mod;
}
void Poly_Inv(int n, VI a, VI &b) {
if (n == 1) return b.clear(), b.push_back(fpow(a[0], Mod - 2));
Poly_Inv((n + 1) >> 1, a, b);
a.resize(n), b.resize(n);
VI d = b;
FFT_prepare(1.5 * n + 0.5);
NTT(a, 1), NTT(d, 1);
for (int i = 0; i < Limit; i++) a[i] = 1ll * a[i] * d[i] % Mod * d[i] % Mod;
NTT(a, 0);
for (int i = (n + 1) >> 1; i < n; i++) b[i] = Mod - a[i];
d.clear();
}
void Poly_Ln(int n, VI a, VI &b) {
VI c, d; Poly_Inv(n, a, c), Poly_Der(n, a, d);
FFT_prepare(n + n);
NTT(c, 1), NTT(d, 1);
for (int i = 0; i < Limit; i++) c[i] = 1ll * c[i] * d[i] % Mod;
NTT(c, 0);
Poly_Int(n, c, b);
}
void Poly_Exp(int n, VI a, VI &b) {
if (n == 1) return b.clear(), b.push_back(1);
Poly_Exp((n + 1) >> 1, a, b), b.resize(n);
VI c, d = b; Poly_Ln(n, b, c);
FFT_prepare(n + 1);
for (int i = 0; i < n; i++) c[i] = (a[i] - c[i] + (i == 0) + Mod) % Mod;
NTT(c, 1), NTT(d, 1);
for (int i = 0; i < Limit; i++) c[i] = 1ll * c[i] * d[i] % Mod;
NTT(c, 0);
for (int i = (n + 1) >> 1; i < n; i++) b[i] = c[i];
}
int N = 2e5, M;
int a[MAX_N];
VI Div(int l, int r) {
if (l == r) return {1, Mod - a[l]};
int mid = (l + r) >> 1;
VI L = Div(l, mid), R = Div(mid + 1, r);
int len = L.size() + R.size() - 1;
FFT_prepare(len);
NTT(L, 1), NTT(R, 1);
for (int i = 0; i < Limit; i++) L[i] = 1ll * L[i] * R[i] % Mod;
NTT(L, 0), L.resize(len);
return L;
}
VI F, G, H, iF, A, B, C, CR, dCR, iC, LnF, pF, ans;
int main () {
#ifndef ONLINE_JUDGE
freopen("cpp.in", "r", stdin);
#endif
for (int i = fac[0] = 1; i <= N; i++) fac[i] = 1ll * fac[i - 1] * i % Mod;
ifc[N] = fpow(fac[N], Mod - 2);
for (int i = N - 1; ~i; i--) ifc[i] = 1ll * ifc[i + 1] * (i + 1) % Mod;
for (int i = inv[0] = 1; i <= N; i++) inv[i] = 1ll * ifc[i] * fac[i - 1] % Mod;
N = gi(), M = gi();
if (N == 1) return puts(M == 0 ? "1" : "0") & 0;
for (int i = 1; i <= N; i++) a[i] = gi();
//prepare F, G
F.resize(N), G.resize(N);
for (int i = 0; i < N; i++) F[i] = 1ll * ifc[i] * fpow(i + 1, M) % Mod;
for (int i = 0; i < N; i++) G[i] = 1ll * ifc[i] * fpow(i + 1, 2 * M) % Mod;
//prepare iF
Poly_Inv(N, F, iF);
//prepare H
FFT_prepare(N << 1);
A = iF, B = G;
NTT(A, 1), NTT(B, 1), H.resize(Limit);
for (int i = 0; i < Limit; i++) H[i] = 1ll * A[i] * B[i] % Mod;
NTT(H, 0);
//prepare C = sigma 1 / (1 - a[i]x)
C = Div(1, N); CR = C; reverse(CR.begin(), CR.end());
Poly_Der(CR.size(), CR, dCR);
reverse(dCR.begin(), dCR.end());
Poly_Inv(C.size(), C, iC);
FFT_prepare(iC.size() + dCR.size());
NTT(iC, 1), NTT(dCR, 1), C.resize(Limit);
for (int i = 0; i < Limit; i++) C[i] = 1ll * iC[i] * dCR[i] % Mod;
NTT(C, 0);
//prepare prod F(a[i]x)
Poly_Ln(N, F, LnF);
for (int i = 0; i < N; i++) LnF[i] = 1ll * LnF[i] * C[i] % Mod;
Poly_Exp(N, LnF, pF);
//prepare sigma H(a[i]x)
for (int i = 0; i < N; i++) H[i] = 1ll * H[i] * C[i] % Mod;
//getans
FFT_prepare(N << 1);
NTT(pF, 1), NTT(H, 1), ans.resize(Limit);
for (int i = 0; i < Limit; i++) ans[i] = 1ll * pF[i] * H[i] % Mod;
NTT(ans, 0);
int pa = 1;
for (int i = 1; i <= N; i++) pa = 1ll * pa * a[i] % Mod;
printf("%lld\n", 1ll * ans[N - 2] * fac[N - 2] % Mod * pa % Mod);
return 0;
}