1. 程式人生 > 其它 >[題解] Atcoder ABC 225 H Social Distance 2 生成函式,分治FFT

[題解] Atcoder ABC 225 H Social Distance 2 生成函式,分治FFT

題目
首先還沒有安排座位的\(m-k\)個人之間是有順序的,所以先把答案乘上\((m-k)!\),就可以把這些人看作不可區分的。

已經確定的k個人把所有座位分成了k+1段。對於第i段,如果我們能求出這一段恰好額外坐j人時的代價總和\(f_{i,j}\),並令\(f_{i,j}\)的普通生成函式為\(F_i(x)\),答案就是\(\prod F_i(x)\)\(m-k\)次項係數。


先考慮k+1段中兩邊都有已經確定的人的k-1段。對於每一段i,列舉其中額外坐的人數j,現在考慮求出\(f_{i,j}\)。令\(g_i\)表示只考慮兩個相鄰的人,他們之間的距離為i時的代價,顯然\(g_i=i\)

。令\(G(x)\)為g的普通生成函式:

\[\begin{align} G(x)&=\sum_{n>0}n\cdot x^n\\ &=x\sum_{n\geq0}(n+1)x^n\\ &=x \cdot \frac1{(1-x)^2}\\ &=\frac x {(1-x)^2}\\ \end{align} \]

令第i段空座位兩邊端點之間的距離為len,發現\(f_{i,j}=G(x)^{j+1}\)的len次項係數(每j個人有j+1段空隙)。由於\(n\leq2e5\),所以可以對每一個j(\(j\geq0\))用這個公式暴力算:\(\frac{1}{(1-x)^m}=\sum_{n\geq0}\binom{n+m-1}{m-1}x^n\)


考慮序列兩頭只有一邊有已經確定的人的段。這裡\(f_{i,j}=G(x)^j\)的0~len次項係數之和。根據上面的公式,我們實際要求的是一個組合數字首和的形式。\(C(n,n)+C(n+1,n)+C(n+2,n)+\cdots+C(m,n)=C(m+1,n+1)\),可以根據這個直接\(O(1)\)算。


對於k=0的情況特殊處理,方法和上面處理序列兩頭的類似。

所以現在已經算出了每一段的\(F(x)\),項數之和是\(O(n)\)的,用分治FFT把所有\(F(x)\)捲起來即可。卷之前把所有\(F(x)\)順序打亂,防止分治的時候被卡。時間複雜度\(O(nlog^2n)\)

點選檢視程式碼
#include <bits/stdc++.h>
#include <atcoder/all>

#define rep(i,n) for(int i=0;i<n;++i)
#define repn(i,n) for(int i=1;i<=n;++i)
#define LL long long
#define pii pair <int,int>
#define fi first
#define se second
#define mpr make_pair
#define pb push_back

using namespace std;
using mint=atcoder::modint998244353;

const LL MOD=998244353;

LL qpow(LL x,LL a)
{
	LL res=x,ret=1;
	while(a>0)
	{
		if((a&1)==1) ret=ret*res%MOD;
		a>>=1;
		res=res*res%MOD;
	}
	return ret;
}

LL n,m,k,a[200010],fac[400010],inv[400010];
vector <vector <mint> > v;

LL C(LL nn,LL mm){return fac[nn]*inv[mm]%MOD*inv[nn-mm]%MOD;}

void deal(LL emp)
{
  if(emp<=0) return;
  vector <mint> tmp;tmp.pb(1);
  repn(seg,emp) tmp.pb(C(emp+seg,seg+seg));
  v.pb(tmp);
}

vector <mint> solve(LL lb,LL ub)
{
  if(lb==ub) return v[lb];
  return atcoder::convolution(solve(lb,(lb+ub)/2),solve((lb+ub)/2+1,ub));
}

int main()
{
  fac[0]=1;repn(i,400005) fac[i]=fac[i-1]*(LL)i%MOD;
  rep(i,400003) inv[i]=qpow(fac[i],MOD-2);
  cin>>n>>m>>k;
  rep(i,k) scanf("%lld",&a[i]);
  if(k==0)
  {
    LL seg=m-1,res=0;
    for(LL len=m-1;len<n;++len) (res+=C(len+seg-1,seg+seg-1)*(n-len))%=MOD;
    cout<<res*fac[m]%MOD<<endl;
    return 0;
  }
  rep(i,k-1)
  {
    LL len=a[i+1]-a[i];
    vector <mint> tmp;
    repn(seg,len)
    {
      LL val=C(len+seg-1,seg+seg-1);
      tmp.pb(val);
    }
    v.pb(tmp);
  }
  deal(a[0]-1);deal(n-a[k-1]);
  random_shuffle(v.begin(),v.end());
  vector <mint> ans=solve(0,v.size()-1);
  cout<<(LL)ans[m-k].val()*fac[m-k]%MOD<<endl;
	return 0;
}