1. 程式人生 > 實用技巧 >LOJ#6503. 「雅禮集訓 2018 Day4」Magic 題解

LOJ#6503. 「雅禮集訓 2018 Day4」Magic 題解

題目連結

對每個 \(a_i,\) 建出一個多項式 \(F(x) = \sum\limits_{j=1}^{a_i} x^j \binom{a_i-1}{j-1},\) \(j\)次項係數表示這些卡牌被分成\(j\)段的方案數,也表示它們\(a_i-j\)處強制為魔術對的方案數

對它們進行\(EGF\)卷積,最後的結果\(G(x)\)\(n-i\)次項係數 \([x^{n-i}]G(x)\) 即為強制有\(i\)處為魔術對的方案數

最後二項式反演即可得到答案為 \(ans = \sum\limits_{i=k}^{n} (-1)^{i-k} \binom{i}{k} [x^{n-i}]G(x).\)

怎麼求出把\(m\)個長度之和\(=n\)的多項式的卷積的結果呢?

用一個堆記錄當前多項式,每次找長度最短的那兩個卷積起來即可,可以證明覆雜度不超過\(\Theta (n\log^2 n)\)

code :

#include <bits/stdc++.h>
#define LL long long
using namespace std;
template <typename T> void read(T &x){
	static char ch; x = 0,ch = getchar();
	while (!isdigit(ch)) ch = getchar();
	while (isdigit(ch)) x = x * 10 + ch - '0',ch = getchar();
}
inline void write(int x){if (x > 9) write(x/10); putchar(x%10+'0'); }
const int P = 998244353,g = 3,L = 131072,M = 20050,N = 100050; 
inline int power(int x,int y){
	static int r; r = 1; while (y){ if (y&1) r = (LL)r * x % P; x = (LL)x * x % P,y >>= 1; }
	return r;
}
int rt[30],irt[30],R[L];
int inv[L+5],fac[L+5],nfac[L+5];
inline int C(int n,int m){
	return (n<0||m<0||n<m) ? 0 : ((LL)fac[n] * nfac[m] % P * nfac[n-m]) % P;
}
inline int getR(int n){
	static int i,l,Lim; l = 0,Lim = 1; while (Lim <= n) Lim <<= 1,++l;
	for (i = 1; i < Lim; ++i) R[i] = (R[i>>1]>>1) | ((i&1)<<l-1);
	return Lim;
}
inline void NTT(int *A,int n){
	register int i,j,k,l,w,w0,x;
	for (i = 1; i < n; ++i) if (i < R[i]) swap(A[i],A[R[i]]);
	for (i = l = 1; i < n; i <<= 1,++l)
	for (w0 = rt[l],j = 0; j < n; j += i<<1)
	for (w = 1,k = j; k < i+j; ++k,w = (LL)w * w0 % P)
		x = (LL)w * A[k+i] % P,A[k+i] = (A[k]<x)?(A[k]+P-x):(A[k]-x),
		A[k] = (A[k]+x>=P)?(A[k]+x-P):(A[k]+x);
}
inline void iNTT(int *A,int n){
	register int i,j,k,l,w,w0,x;
	for (i = 1; i < n; ++i) if (i < R[i]) swap(A[i],A[R[i]]);
	for (i = l = 1; i < n; i <<= 1,++l)
	for (w0 = irt[l],j = 0; j < n; j += i<<1)
	for (w = 1,k = j; k < i+j; ++k,w = (LL)w * w0 % P)
		x = (LL)w * A[k+i] % P,A[k+i] = (A[k]<x)?(A[k]+P-x):(A[k]-x),
		A[k] = (A[k]+x>=P)?(A[k]+x-P):(A[k]+x);
	for (i = 0,w = inv[n]; i < n; ++i) A[i] = (LL)A[i] * w % P;
}
typedef vector<int> arr;
int F[L],G[L];
inline void Mul(arr &A,arr &B,arr &C){
	int n = A.size()-1,m = B.size()-1,Li = getR(n+m); register int i;
	for (memset(F,0,Li<<2),i = 0; i <= n; ++i) F[i] = A[i];
	for (memset(G,0,Li<<2),i = 0; i <= m; ++i) G[i] = B[i];
	NTT(F,Li); NTT(G,Li); for (i = 0; i < Li; ++i) F[i] = (LL)F[i] * G[i] % P; iNTT(F,Li);
	C.resize(n+m+1); for (i = 0; i <= n+m; ++i) C[i] = F[i];
}
arr A[M<<1]; int cnto;
struct Node{
	int id,len;
	bool operator < (const Node t) const{ return len > t.len; }
}tmp;
priority_queue<Node>H;
int m,n,k,a[M];
inline void build(int n){
	A[++cnto].resize(n+1);
	for (int i = 0; i <= n; ++i) A[cnto][i] = (LL)nfac[i] * C(n-1,i-1) % P;
}
inline void work(){
	int i,id1,id2;
	for (i = 1; i <= cnto; ++i) tmp.id = i,tmp.len = A[i].size(),H.push(tmp);
	while (H.size() > 1){
		id1 = H.top().id,H.pop(),id2 = H.top().id,H.pop();
		Mul(A[id1],A[id2],A[++cnto]);
		tmp.id = cnto,tmp.len = A[cnto].size(),H.push(tmp);
	}
	for (i = 0; i <= n; ++i) F[n-i] = (LL)fac[i] * A[cnto][i] % P;
}
int main(){
	int i,j;
	for (i = 1,j = 2; i <= 25; ++i,j <<= 1) rt[i] = power(3,(P-1)/j),irt[i] = power(rt[i],P-2);
	inv[0] = inv[1] = nfac[0] = fac[0] = nfac[1] = fac[1] = 1;
	for (i = 2; i <= L; ++i){
		fac[i] = (LL)fac[i-1] * i % P;
		inv[i] = (LL)(P-P/i) * inv[P%i] % P;
		nfac[i] = (LL)nfac[i-1] * inv[i] % P;
	}
	read(m),read(n),read(k);
	for (i = 1; i <= m; ++i) read(a[i]),build(a[i]);
	work();
	int ans = 0;
	for (i = k; i <= n; ++i){
		if ((k-i) & 1) ans = (ans + P - (LL)C(i,k) * F[i] % P) % P;
		else ans = (ans + (LL)C(i,k) * F[i] % P) % P;
	}
	cout << ans << '\n';
	return 0;
}