[生成函式]ARC106F
阿新 • • 發佈:2020-10-25
題目大意:
有\(n\)的點,每個點有一個權值\(d_i\),表示這個點上有多少個孔。你可以連線\(n-1\)條邊,每條邊可以連線分別屬於兩個點的兩個孔。一個孔不能連兩條邊。要求最後這樣使這\(n\)個聯通,即找到完全圖的一個生成樹。
\(n\le 2\times 10^5\)
前置知識
來自於prufer序列的一個結論:在完全圖中,如果\(n\)個點的度數分別為\(D_1,D_2,\cdots,D_n\),那麼滿足這種條件的生成樹個數是
\[\frac{(n-2)!}{\prod_{k=1}^n(D_k-1)!} \]吸收/提取恆等式:
\[\binom{n}{m}=\frac{n}{m}\binom{n-1}{m-1} \]題解
考慮在最後的一種生成樹中,\(n\)個點的度數分別為\(D_1,D_2,\cdots,D_n\)。因為每個點在選擇要連線的孔時是有序的,則這種方案的貢獻是:
\[\frac{(n-2)!}{\prod_{k=1}^n(D_k-1)!}\prod_{k=1}^n\binom{d_k}{D_k}D_k! \]注意到\(\sum\limits_{k=1}^nD_k=2(n-1)\),所以事實上,我們可以構造出一個卷積來求答案\(S\)。
\[S=(n-2)![x^{2n-2}]\prod_{k=1}^nF_k(x) \]\[\begin{align*} F_k(x)&=\sum_{k=0}^\infty\binom{d_k}{k}\frac{k!}{(k-1)!}x^k\\ &=\sum_{k=0}^\infty\binom{d_k}{k}k x^k\\ &=\sum_{k=1}^\infty\binom{d_k-1}{k-1}d_kx^k\\ &=d_k\sum_{k=0}^\infty\binom{d_k-1}{k}x^{k+1}\\ &=d_kx\sum_{k=0}^\infty\binom{d_k-1}{k}x^k\\ &=d_kx(1+x)^{d_k-1} \end{align*} \]到了這一步,我們仍然無法快速地求出答案。不過,對於乘積,我們有固定的套路:求\(ln\),求導,積分。
\[\begin{align*} S&=(n-2)![x^{2n-2}]\prod_{k=1}^nd_kx(1+x)^{d_k-1}\\ &=(n-2)![x^{2n-2}]\exp\left(\sum_{k=1}^n\ln(d_kx(1+x)^{d_k-1})\right) \end{align*} \]我們單獨考慮。積分時要記得加回常數。
\[\begin{align*} G_k(x)&=\ln(d_kx(1+x)^{d_k-1})\\ &=\ln d_k+\ln x+(d_k-1)\ln(1+x)\\ G'_k(x)&=\frac{1}{x}+(d_k-1)\frac{1}{1+x}\\ &=\frac{1}{x}+(d_k-1)\sum_{k=0}^\infty(-1)^kx^k\\ G_k(x)&=\int G'_k(x)\mathbb{d}x\\ &=\ln d_k+\ln x+(d_k-1)\sum_{k=0}^\infty(-1)^k\frac{x^{k+1}}{k+1}\\ &=\ln d_k+\ln x+(d_k-1)\sum_{k=1}^\infty\frac{(-1)^{k-1}}{k}x^k \end{align*} \]有趣的是,我們發現後面的部分與\(d_k\)無關。記\(S=\sum\limits_{k=1}^nd_k\),\(T=\prod\limits_{k=1}^nd_k\),我們有:
\[\begin{align*} S&=(n-2)![x^{2n-2}]\exp\left(\sum_{x=1}^n\left(\ln d_k+\ln x+(d_k-1)\sum_{k=1}^\infty\frac{(-1)^{k-1}}{k}x^k\right)\right)\\ &=(n-2)![x^{2n-2}]Tx^n\exp\left((S-n)\sum_{k=1}^\infty\frac{(-1)^{k-1}}{k}x^k\right)\\ &=(n-2)!T[x^{n-2}]\exp\left((S-n)\sum_{k=1}^\infty\frac{(-1)^{k-1}}{k}x^k\right) \end{align*} \]至此,問題解決。套多項式求\(\exp\)模板即可。
時間複雜度\(O(n\log n)\)
程式碼
#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
const ll mod=998244353,g=3;
const int N=524289;
int n,m,rev[N];
ll sn,pn=1,inv[N],gn[2][N],f[N];
inline ll add(ll a,ll b){return a+b>=mod?a+b-mod:a+b;}
inline ll cut(ll a,ll b){return a-b<0?a-b+mod:a-b;}
inline ll mul(ll a,ll b){return a*b%mod;}
ll fpow(ll a,ll b){ll bs=1;while(b){if(b&1)bs=mul(bs,a);a=mul(a,a);b>>=1;}return bs;}
inline void init(){
for(m=1;m<=n;m<<=1);
gn[0][0]=gn[1][0]=inv[1]=1;
gn[0][1]=fpow(g,(mod-1)/(m<<1));
gn[1][1]=fpow(gn[0][1],mod-2);
for(int i=2;i<(m<<1);i++)gn[0][i]=mul(gn[0][i-1],gn[0][1]);
for(int i=2;i<(m<<1);i++)gn[1][i]=mul(gn[1][i-1],gn[1][1]);
for(int i=2;i<=(m<<1);i++)inv[i]=mul(mod-mod/i,inv[mod%i]);
}
void NTT(ll c[],int n,int tp=0){
for(int i=0;i<n;i++){
rev[i]=(rev[i>>1]>>1)|((i&1)*(n>>1));
if(i<rev[i])swap(c[i],c[rev[i]]);
}
for(int i=1;i<n;i<<=1){
for(int j=0;j<n;j+=(i<<1)){
for(int k=0;k<i;k++){
ll x=c[j+k],y=mul(c[j+k+i],gn[tp][m/i*k]);
c[j+k]=add(x,y);
c[j+k+i]=cut(x,y);
}
}
}
}
void INTT(ll c[],int n){
NTT(c,n,1);
for(int i=0;i<n;i++)c[i]=mul(c[i],inv[n]);
}
void inverse(ll c[],int n){
static ll t[N],tma[N];
t[0]=fpow(c[0],mod-2);
for(int k=2;k<=n;k<<=1){
for(int i=0;i<(k<<1);i++)tma[i]=(i<k?c[i]:0);
for(int i=(k>>1);i<(k<<1);i++)t[i]=0;
NTT(tma,k<<1);
NTT(t,k<<1);
for(int i=0;i<(k<<1);i++)t[i]=cut(add(t[i],t[i]),mul(tma[i],mul(t[i],t[i])));
INTT(t,k<<1);
}
for(int i=0;i<n;i++)c[i]=t[i];
}
void derivative(ll c[],int n){for(int i=0;i<n;i++)c[i]=mul(c[i+1],i+1);}
void integrate(ll c[],int n){for(int i=n-1;i>=1;i--)c[i]=mul(c[i-1],inv[i]);c[0]=0;}
void ln(ll c[],int n){
static ll t[N];
for(int i=0;i<(n<<1);i++)t[i]=(i<n?c[i]:0);
derivative(t,n);
inverse(c,n);
NTT(t,n<<1);
NTT(c,n<<1);
for(int i=0;i<(n<<1);i++)c[i]=mul(c[i],t[i]);
INTT(c,n<<1);
for(int i=n;i<(n<<1);i++)c[i]=0;
integrate(c,n);
}
void exp(ll c[],int n){
static ll t[N],ta[N];
t[0]=1;
for(int k=2;k<=n;k<<=1){
for(int i=0;i<(k<<1);i++)ta[i]=t[i];
ln(ta,k);
for(int i=0;i<k;i++)ta[i]=cut(c[i],ta[i]);
ta[0]=add(ta[0],1);
NTT(t,k<<1);
NTT(ta,k<<1);
for(int i=0;i<(k<<1);i++)t[i]=mul(t[i],ta[i]);
INTT(t,k<<1);
for(int i=k;i<(k<<1);i++)t[i]=0;
}
for(int i=0;i<n;i++)c[i]=t[i];
}
int main(){
scanf("%d",&n);
for(int i=1;i<=n;i++){
int x;
scanf("%d",&x);
sn=add(sn,x);
pn=mul(pn,x);
}
sn=cut(sn,n);
init();
for(int i=1;i<m;i++)f[i]=mul(sn,mul(i&1?1:mod-1,inv[i]));
exp(f,m);
ll facn=1;
for(int i=1;i<=n-2;i++)facn=mul(facn,i);
printf("%lld\n",mul(pn,mul(facn,f[n-2])));
}