[題解] Atcoder ABC 225 H Social Distance 2 生成函式,分治FFT
題目
首先還沒有安排座位的\(m-k\)個人之間是有順序的,所以先把答案乘上\((m-k)!\),就可以把這些人看作不可區分的。
已經確定的k個人把所有座位分成了k+1段。對於第i段,如果我們能求出這一段恰好額外坐j人時的代價總和\(f_{i,j}\),並令\(f_{i,j}\)的普通生成函式為\(F_i(x)\),答案就是\(\prod F_i(x)\)的\(m-k\)次項係數。
先考慮k+1段中兩邊都有已經確定的人的k-1段。對於每一段i,列舉其中額外坐的人數j,現在考慮求出\(f_{i,j}\)。令\(g_i\)表示只考慮兩個相鄰的人,他們之間的距離為i時的代價,顯然\(g_i=i\)
令第i段空座位兩邊端點之間的距離為len,發現\(f_{i,j}=G(x)^{j+1}\)的len次項係數(每j個人有j+1段空隙)。由於\(n\leq2e5\),所以可以對每一個j(\(j\geq0\))用這個公式暴力算:\(\frac{1}{(1-x)^m}=\sum_{n\geq0}\binom{n+m-1}{m-1}x^n\)
考慮序列兩頭只有一邊有已經確定的人的段。這裡\(f_{i,j}=G(x)^j\)的0~len次項係數之和。根據上面的公式,我們實際要求的是一個組合數字首和的形式。\(C(n,n)+C(n+1,n)+C(n+2,n)+\cdots+C(m,n)=C(m+1,n+1)\),可以根據這個直接\(O(1)\)算。
對於k=0的情況特殊處理,方法和上面處理序列兩頭的類似。
所以現在已經算出了每一段的\(F(x)\),項數之和是\(O(n)\)的,用分治FFT把所有\(F(x)\)捲起來即可。卷之前把所有\(F(x)\)順序打亂,防止分治的時候被卡。時間複雜度\(O(nlog^2n)\)。
點選檢視程式碼
#include <bits/stdc++.h>
#include <atcoder/all>
#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back
using namespace std;
using mint=atcoder::modint998244353;
const LL MOD=998244353;
LL qpow(LL x,LL a)
{
LL res=x,ret=1;
while(a>0)
{
if((a&1)==1) ret=ret*res%MOD;
a>>=1;
res=res*res%MOD;
}
return ret;
}
LL n,m,k,a[200010],fac[400010],inv[400010];
vector <vector <mint> > v;
LL C(LL nn,LL mm){return fac[nn]*inv[mm]%MOD*inv[nn-mm]%MOD;}
void deal(LL emp)
{
if(emp<=0) return;
vector <mint> tmp;tmp.pb(1);
repn(seg,emp) tmp.pb(C(emp+seg,seg+seg));
v.pb(tmp);
}
vector <mint> solve(LL lb,LL ub)
{
if(lb==ub) return v[lb];
return atcoder::convolution(solve(lb,(lb+ub)/2),solve((lb+ub)/2+1,ub));
}
int main()
{
fac[0]=1;repn(i,400005) fac[i]=fac[i-1]*(LL)i%MOD;
rep(i,400003) inv[i]=qpow(fac[i],MOD-2);
cin>>n>>m>>k;
rep(i,k) scanf("%lld",&a[i]);
if(k==0)
{
LL seg=m-1,res=0;
for(LL len=m-1;len<n;++len) (res+=C(len+seg-1,seg+seg-1)*(n-len))%=MOD;
cout<<res*fac[m]%MOD<<endl;
return 0;
}
rep(i,k-1)
{
LL len=a[i+1]-a[i];
vector <mint> tmp;
repn(seg,len)
{
LL val=C(len+seg-1,seg+seg-1);
tmp.pb(val);
}
v.pb(tmp);
}
deal(a[0]-1);deal(n-a[k-1]);
random_shuffle(v.begin(),v.end());
vector <mint> ans=solve(0,v.size()-1);
cout<<(LL)ans[m-k].val()*fac[m-k]%MOD<<endl;
return 0;
}