LOJ#2541. 「PKUWC2018」獵人殺 容斥+分治NTT
阿新 • • 發佈:2020-07-24
真——分治NTT
code:
#include <cstdio> #include <vector> #include <cstring> #include <algorithm> #define N 100009 #define ll long long #define mod 998244353 #define pb push_back #define setIO(s) freopen(s".in","r",stdin) using namespace std; int A[N<<2],B[N<<2],n,w[N],fac[N],inv[N<<1]; void init() { fac[0]=1; for(int i=1;i<N;++i) fac[i]=(ll)fac[i-1]*i%mod; inv[1]=1; for(int i=2;i<(N<<1);++i) { inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; } inv[0]=1; } int qpow(int x,int y) { int tmp=1; for(;y;y>>=1,x=(ll)x*x%mod) { if(y&1) tmp=(ll)tmp*x%mod; } return tmp; } int get_inv(int x) { return qpow(x,mod-2); } void NTT(int *a,int len,int op) { for(int i=0,k=0;i<len;++i) { if(i>k) swap(a[i],a[k]); for(int j=len>>1;(k^=j)<j;j>>=1); } for(int l=1;l<len;l<<=1) { int wn=qpow(3,(mod-1)/(l<<1)); if(op==-1) { wn=get_inv(wn); } for(int i=0;i<len;i+=l<<1) { int w=1; for(int j=0;j<l;++j) { int x=a[i+j],y=(ll)w*a[i+j+l]%mod; a[i+j]=(ll)(x+y)%mod; a[i+j+l]=(ll)(x-y+mod)%mod; w=(ll)w*wn%mod; } } } if(op==-1) { int iv=get_inv(len); for(int i=0;i<len;++i) { a[i]=(ll)a[i]*iv%mod; } } } struct poly { int len; vector<int>a; poly() { len=0,a.clear();} void push(int x) { ++len,a.pb(x); } void resi(int x) { while(len<x) a.pb(0),++len; } poly operator*(const poly &b) const { int l=len+b.len,lim; for(lim=1;lim<l;lim<<=1); for(int i=0;i<len;++i) A[i]=a[i]; for(int i=len;i<lim;++i) A[i]=0; for(int i=0;i<b.len;++i) B[i]=b.a[i]; for(int i=b.len;i<lim;++i) B[i]=0; NTT(A,lim,1),NTT(B,lim,1); for(int i=0;i<lim;++i) { A[i]=(ll)A[i]*B[i]%mod; } NTT(A,lim,-1); poly c; for(int i=0;i<len+b.len-1;++i) { c.push(A[i]); } return c; } }F,G; poly solve(int l,int r) { if(l==r) { poly c; c.push(1); c.resi(w[l]); c.push(mod-1); return c; } int mid=(l+r)>>1; return solve(l,mid)*solve(mid+1,r); } int main() { //setIO("input"); scanf("%d",&n); int tot=0; for(int i=1;i<=n;++i) { scanf("%d",&w[i]); tot+=w[i]; } init(); poly ans=solve(2,n); int fin=0; for(int i=0;i<=tot-w[1];++i) { (fin+=(ll)w[1]*ans.a[i]%mod*inv[i+w[1]]%mod)%=mod; } printf("%d\n",fin); return 0; }