[TJOI2019] 洛谷P5339 唱、跳、rap和籃球
問題描述
大中鋒的學院要組織學生參觀博物館,要求學生們在博物館中排成一隊進行參觀。他的同學可以分為四類:一部分最喜歡唱、一部分最喜歡跳、一部分最喜歡rap,還有一部分最喜歡籃球。如果佇列中k,k + 1,k + 2,k + 3位置上的同學依次,最喜歡唱、最喜歡跳、最喜歡rap、最喜歡籃球,那麼他們就會聚在一起討論蔡徐坤。大中鋒不希望這種事情發生,因為這會使得隊伍顯得很亂。大中鋒想知道有多少種排隊的方法,不會有學生聚在一起討論蔡徐坤。兩個學生隊伍被認為是不同的,當且僅當兩個隊伍中至少有一個位置上的學生的喜好不同。由於合法的隊伍可能會有很多種,種類數對998244353取模。
輸入格式
輸入資料只有一行。每行5個整數,第一個整數n,代表大中鋒的學院要組織多少人去參觀博物館。接下來四個整數a、b、c、d,分別代表學生中最喜歡唱的人數、最喜歡跳的人數、最喜歡rap的人數和最喜歡籃球的人數。保證\(a+b+c+d \ge n\)
輸出格式
每組資料輸出一個整數,代表你可以安排出多少種不同的學生隊伍,使得隊伍中沒有學生聚在一起討論蔡徐坤。結果對998244353取模。
樣例輸入
4 4 3 2 1
樣例輸出
174
資料範圍
對於20%的資料,有\(n=a=b=c=d\le500\)
對於100%的資料,有\(n \le 1000\), \(a, b, c, d \le 500\)
解析
我們可以考慮容斥。設 \(f_i\) 表示至少有 \(i\) 組學生會討論蔡徐坤,那麼我們只需要確定每組的第一個學生的位置即可,即
\[f_i=C_{n-3i}^i \]
對於剩下的 \(n-4i\) 個位置,我們可以隨便排列。我們可以發現,這其實是一個可重排列,但是要滿足排列長度等於 \(n-4i\)
\[\begin{align} g_i &= \sum_{x1\le a-i,x2\le b-i,x3\le c-i,x4\le d-i} \frac{(n-4i)!}{x1!x2!x3!x4!}\ \ [x1+x2+x3+x4=n-4i]\\ &= (n-4i)!\sum_{x1\le a-i,x2\le b-i,x3\le c-i,x4\le d-i} \frac{1}{x1!x2!x3!x4!}\ \ [x1+x2+x3+x4=n-4i] \end{align} \]
實際上,\(g_i\)
\[Ans=\sum_{i=0}^{n/4} (-1)^i f_i g_i \]
其中當 \(i=0\) 時,得到的是總方案數。
程式碼
#include <iostream>
#include <cstdio>
#define int long long
#define N 10002
using namespace std;
const int mod=998244353;
const int G=3;
int n,a,b,c,d,i,j,fac[N],inv[N],r[N],A[N],B[N],C[N],D[N],ans;
int read()
{
char c=getchar();
int w=0;
while(c<'0'||c>'9') c=getchar();
while(c<='9'&&c>='0'){
w=w*10+c-'0';
c=getchar();
}
return w;
}
int poww(int a,int b)
{
int ans=1,base=a;
while(b){
if(b&1) ans=ans*base%mod;
base=base*base%mod;
b>>=1;
}
return ans;
}
int cal(int n,int m)
{
return fac[n]*inv[m]%mod*inv[n-m]%mod;
}
void NTT(int *a,int inv,int n)
{
for(int i=0;i<n;i++){
if(i<r[i]) swap(a[i],a[r[i]]);
}
for(int l=2;l<=n;l<<=1){
int mid=l/2;
int cur=poww(G,(mod-1)/l);
if(inv==-1) cur=poww(cur,mod-2);
for(int i=0;i<n;i+=l){
int omg=1;
for(int j=0;j<mid;j++,omg=omg*cur%mod){
int tmp=omg*a[i+j+mid]%mod;
a[i+j+mid]=(a[i+j]-tmp+mod)%mod;
a[i+j]=(a[i+j]+tmp)%mod;
}
}
}
if(inv==-1){
for(int i=0;i<n;i++) a[i]=a[i]*poww(n,mod-2)%mod;
}
}
signed main()
{
n=read();a=read();b=read();c=read();d=read();
for(i=fac[0]=1;i<=n;i++) fac[i]=fac[i-1]*i%mod;
inv[n]=poww(fac[n],mod-2);
for(i=n-1;i>=0;i--) inv[i]=inv[i+1]*(i+1)%mod;
for(i=0;i<=n/4;i++){
if(a<i||b<i||c<i||d<i) continue;
int m=1,lim=0;
while(m<a+b+c+d-4*i) m<<=1,lim++;
for(j=0;j<m;j++) r[j]=(r[j>>1]>>1)|((j&1)<<(lim-1));
for(j=0;j<m;j++) A[j]=(j<=a-i)?inv[j]:0;
for(j=0;j<m;j++) B[j]=(j<=b-i)?inv[j]:0;
for(j=0;j<m;j++) C[j]=(j<=c-i)?inv[j]:0;
for(j=0;j<m;j++) D[j]=(j<=d-i)?inv[j]:0;
NTT(A,1,m);NTT(B,1,m);NTT(C,1,m);NTT(D,1,m);
for(j=0;j<m;j++) A[j]=A[j]*B[j]%mod*C[j]%mod*D[j]%mod;
NTT(A,-1,m);
int tmp=cal(n-3*i,i)*A[n-4*i]%mod*fac[n-4*i]%mod;
if(i%2==0) ans=(ans+tmp)%mod;
else ans=(ans-tmp+mod)%mod;
}
printf("%lld\n",ans);
return 0;
}