1. 程式人生 > 實用技巧 >Easy【生成函式】

Easy【生成函式】

題意

若序列 \(A,B\) 滿足 \(\sum_{i=1}^{K}{a_i}=N,\sum_{i=1}^{K}{b_i}=M\) ,則其對答案的貢獻是:\(P=\prod_{i=1}^{K}{\min(a_i,b_i)}\),問所有滿足條件的序列的總貢獻為多少。

\(1\leq N,M \leq 10^6,1\leq K \leq \min(N,M)\)

https://ac.nowcoder.com/acm/contest/5670/C

分析

如果本題求的是構造的序列 \(A,B\) 的方案總數,那麼可以構造下列的生成函式:

\[S=(x+x^2+x^3+\dots +x^N)^K*(y+y^2+y^3+\dots +y^M)^K \]

答案為展開式中 \(x^Ny^M\) 的係數。

但題目要求:

\[\prod_{i=1}^{K}{\min(a_i,b_i)} \]

因此,可以構造生成函式:?

\[S=\sum_{i,j\in [1,\infty)}{\min(i,j)x^iy^j} \]

那麼,最終的答案為 \(S^K\) 的展開式中 \(x^Ny^M\) 的係數。

\[\begin{align} S &= xy+xy^2+xy^3+\dots \\ &+x^2y+2x^2y^2+2x^2y^3+\dots\\ &+x^3y+2x^3y^2+3x^3y^3+\dots\\ \end{align} \]

兩邊同時乘上 \(x\)

,有:

\[\begin{align} xS &= 0+0+0+\dots\\ &+ x^2y+x^2y^2+x^2y^3+\dots\\ &+ x^3y+2x^3y^2+2x^3y^3+\dots\\ \end{align} \]

兩式相減,得:

\[\begin{align} S-xS &= xy+xy^2+xy^3+\dots\\ &+0+x^2y^2+x^2y^3+\dots\\ &+0+0+x^3y^3+\dots \end{align} \]

\(f(1)=xy(1+y+y^2+y^3+\dots)\ ,\ f(n)=xyf(n-1)\),則:

\[S-xS=\sum_{i=1}^{\infty}{f(i)}=f(1)*(1+xy+x^2y^2+\dots)=xy*(1+y+y^2+\dots)*(1+xy+x^2y^2+\dots) \]

\(G(x)=1+x+x^2+x^3+\dots\),那麼 \(S(1-x)=xy*G(y)*G(xy)\)

因為 \(G(x)=xG(x)+1\),即 \(G(x)=\frac{1}{1-x}\),因此:\(S=xy*G(x)*G(y)*G(xy)\),所以有:

\[S^K=x^Ky^KG(x)^KG(y)^KG(xy)^K \]

又根據廣義二項式定理

\[\frac{1}{(1-x)^n}=\sum_{i=0}^{\infty}{C_{n+i-1}^{i-1}x^i} \]

\[G(x)^K=\sum_{i=0}^{\infty}{C_{K+i-1}^{i-1}{x^i}} \]

在生成函式中,\(x^Ny^M\) 的係數為答案,而多項式前面已經有了 \(x^Ky^K\)。因此,可以在 \([0,\min(N,M)-K]\) 內列舉 \(xy\) 的係數,然後根據 \(N\)\(M\) 補全出 \(x\)\(y\) 的係數,最終得到答案:

\[ans=\sum_{i=0}^{\min(N,M)-K}{C_{K+(N-K-i)-1}^{K-1}*C_{K+(M-K-i)-1}^{K-1}*C_{K+i-1}^{K-1}} \]

程式碼

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=1e6+6;
ll fac[N],inv[N];
ll power(ll a,ll b)
{
    ll res=1;
    a%=mod;
    while(b)
    {
        if(b&1) res=res*a%mod;
        a=a*a%mod;
        b>>=1;
    }
    return res;
}
void init()
{
    int maxn=1e6;
    fac[0]=1;
    for(int i=1;i<=maxn;i++)
        fac[i]=fac[i-1]*i%mod;
    inv[maxn]=power(fac[maxn],mod-2);
    for(int i=maxn-1;i>=0;i--)
        inv[i]=1LL*inv[i+1]*(i+1)%mod;
}
int main()
{
    int T,n,m,k;
    init();
    scanf("%d",&T);
    while(T--)
    {
        scanf("%d%d%d",&n,&m,&k);
        ll ans=0;
        int minn=min(n,m);
        for(int i=0;i<=minn-k;i++)//列舉
        {
            ll t1=fac[k+i-1]*inv[k-1]%mod*inv[i]%mod;
            ll t2=fac[n-i-1]*inv[k-1]%mod*inv[n-k-i]%mod;
            ll t3=fac[m-i-1]*inv[k-1]%mod*inv[m-k-i]%mod;
            ans=(ans+t1*t2%mod*t3)%mod;
        }
        printf("%lld\n",ans);
    }
    return 0;
}

參考部落格:https://zhuanlan.zhihu.com/p/234938833