LOJ#3120. 珍珠 容斥+生成函式+NTT
阿新 • • 發佈:2020-07-22
神仙多項式可還行.
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 100009 #define ll long long #define mod 998244353 #define setIO(s) freopen(s".in","r",stdin) using namespace std; int fac[N],inv[N],A[N<<2],B[N<<2],f[N],g[N],lw[N]; int ADD(int x,int y) { return (ll)(x+y)%mod; } int DEC(int x,int y) { return (ll)(x-y+mod)%mod; } int MUL(int x,int y) { return (ll)x*y%mod; } int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) if(y&1) { tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)); if(op==-1) { wn=get_inv(wn); } for(int i=0;i<len;i+=l<<1) { int w=1,x,y; for(int j=0;j<l;++j) { x=a[i+j],y=(ll)w*a[i+j+l]%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int iv=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*iv%mod; } } } void init() { fac[0]=lw[0]=1; lw[1]=get_inv(2); for(int i=1;i<N;++i) { fac[i]=(ll)fac[i-1]*i%mod; lw[i]=(ll)lw[i-1]*lw[1]%mod; } inv[1]=1; for(int i=2;i<N;++i) { inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; } inv[0]=1; for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod; } int C(int x,int y) { return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod; } int main() { // setIO("input"); int n,m,D,lim; scanf("%d%d%d",&D,&n,&m); if(n-2*m<0) { printf("0\n"); return 0; } if(n-2*m>=D) { printf("%d\n",qpow(D,n)); return 0; } init(); for(lim=1;lim<=(D<<1);lim<<=1); for(int i=0;i<=D;++i) { int d=(i&1)?mod-1:1; A[i]=(ll)d*inv[i]%mod*qpow(DEC(D,2*i),n)%mod; B[i]=inv[i]; } NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) { A[i]=(ll)A[i]*B[i]%mod; } NTT(A,lim,-1); for(int i=0;i<=D;++i) { f[i]=(ll)A[i]*C(D,i)%mod*fac[i]%mod*lw[i]%mod; } for(int i=0;i<lim;++i) { A[i]=B[i]=0; } for(int i=0;i<=D;++i) { int d=(i&1)?mod-1:1; A[i]=(ll)f[D-i]*fac[D-i]%mod; B[i]=(ll)d*inv[i]%mod; } NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) { A[i]=(ll)A[i]*B[i]%mod; } NTT(A,lim,-1); int ans=0; for(int i=0;i<=n-2*m;++i) { ans=ADD(ans,(ll)inv[i]*A[D-i]%mod); } printf("%d\n",ans); return 0; }