題解 大朋友和多叉樹
阿新 • • 發佈:2020-07-12
題目大意
給出集合\(S\)和整數\(n\),求出有多少個多叉樹使得每個節點的孩子個數都在\(S\)中,且葉子個數為\(n\)。
思路
啊,居然沒有看出來可以用拉格朗日反演,果然還是自己太菜了。。。
我們設答案的生成函式為\(F\),\(G\)為集合\(S\)的生成函式,可以得到:
\[F=\sum_{i\in S} F^i+x \]
應該很顯然吧,這裡就懶得解釋了。
我們發現這個式子可以化成:
\[F=G(F)+x \]
我們如果設\(C(x)=x-G(c)\),那我們就可以得到:
\[C(F)=x \]
我們就發現\(C\)和\(F\)互為複合逆,於是使用拉格朗日反演可以得到:
\[[x^n]F(x)=\frac{1}{n}[x^{n-1}](\frac{x}{C(x)})^n \]
\[=\frac{1}{n}[x^{n-1}]\exp(n\ln \frac{x}{C(x)}) \]
於是,我們就可以做到\(\Theta(n\log n)\)解決這個問題了。
\(\text {Code}\)
#include <bits/stdc++.h> using namespace std; #define Int register int #define mod 950009857 #define ll long long #define MAXN 400005 #define Gi 7 int quick_pow (int a,int b,int c){ int res = 1;for (;b;b >>= 1,a = 1ll * a * a % c) if (b & 1) res = 1ll * res * a % c; return res; } int limit = 1,l,r[MAXN]; void NTT (int *a,int type){ for (Int i = 0;i < limit;++ i) if (i < r[i]) swap (a[i],a[r[i]]); for (Int mid = 1;mid < limit;mid <<= 1){ int Wn = quick_pow (Gi,(mod - 1) / (mid << 1),mod); if (type == -1) Wn = quick_pow (Wn,mod - 2,mod); for (Int R = mid << 1,j = 0;j < limit;j += R){ for (Int k = 0,w = 1;k < mid;++ k,w = 1ll * w * Wn % mod){ int x = a[j + k],y = 1ll * w * a[j + k + mid] % mod; a[j + k] = (x + y) % mod,a[j + k + mid] = (x + mod - y) % mod; } } } if (type == 1) return ; int Inv = quick_pow (limit,mod - 2,mod); for (Int i = 0;i < limit;++ i) a[i] = 1ll * a[i] * Inv % mod; } int c[MAXN]; void Solve (int len,int *a,int *b) { if (len == 1) return b[0] = quick_pow (a[0],mod - 2,mod),void (); Solve ((len + 1) >> 1,a,b); limit = 1,l = 0; while (limit < len << 1) limit <<= 1,l ++; for (Int i = 0;i < limit;++ i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); for (Int i = 0;i < len;++ i) c[i] = a[i]; for (Int i = len;i < limit;++ i) c[i] = 0; NTT (c,1);NTT (b,1); for (Int i = 0;i < limit;++ i) b[i] = 1ll * b[i] * (2 + mod - 1ll * c[i] * b[i] % mod) % mod; NTT (b,-1); for (Int i = len;i < limit;++ i) b[i] = 0; } void deravitive (int *a,int n){ for (Int i = 1;i <= n;++ i) a[i - 1] = 1ll * a[i] * i % mod; a[n] = 0; } void inter (int *a,int n){ for (Int i = n;i >= 1;-- i) a[i] = 1ll * a[i - 1] * quick_pow (i,mod - 2,mod) % mod; a[0] = 0; } int b[MAXN]; void Ln (int *a,int n){ memset (b,0,sizeof (b)); Solve (n,a,b);deravitive (a,n); while (limit < (n << 1)) limit <<= 1,l ++; for (Int i = 0;i < limit;++ i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1)); NTT (a,1),NTT (b,1); for (Int i = 0;i < limit;++ i) a[i] = 1ll * a[i] * b[i] % mod; NTT (a,-1),inter (a,n); for (Int i = n + 1;i < limit;++ i) a[i] = 0; } int F0[MAXN]; void Exp (int *a,int *B,int n){ if (n == 1) return B[0] = 1,void (); Exp (a,B,(n + 1) >> 1); for (Int i = 0;i < limit;++ i) F0[i] = B[i]; Ln (F0,n); F0[0] = (a[0] + 1 + mod - F0[0]) % mod; for (Int i = 1;i < n;++ i) F0[i] = (a[i] + mod - F0[i]) % mod; NTT (F0,1);NTT (B,1); for (Int i = 0;i < limit;++ i) B[i] = 1ll * F0[i] * B[i] % mod; NTT (B,-1); for (Int i = n;i < limit;++ i) B[i] = 0; } template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;} template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);} template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');} int s,m,G[MAXN],G_[MAXN]; signed main(){ read (s,m); for (Int i = 1,x;i <= m;++ i) read (x),G[x - 1] = mod - 1; G[0] = 1,Solve (s + 1,G,G_),memset (G,0,sizeof (G)),Ln (G_,s); for (Int i = 0;i <= s;++ i) G_[i] = 1ll * G_[i] * s % mod; Exp (G_,G,s + 1),write (1ll * G[s - 1] * quick_pow (s,mod - 2,mod) % mod),putchar ('\n'); return 0; }