1. 程式人生 > 實用技巧 >LOJ#3120. 珍珠 容斥+生成函式+NTT

LOJ#3120. 珍珠 容斥+生成函式+NTT

神仙多項式可還行.

code:

#include <cstdio>  
#include <vector>
#include <cstring>
#include <algorithm>    
#define N 100009  
#define ll long long 
#define mod 998244353
#define setIO(s) freopen(s".in","r",stdin) 
using namespace std;   
int fac[N],inv[N],A[N<<2],B[N<<2],f[N],g[N],lw[N];    
int ADD(int x,int y) { 
    return (ll)(x+y)%mod;  
} 
int DEC(int x,int y) { 
    return (ll)(x-y+mod)%mod; 
} 
int MUL(int x,int y) {  
    return (ll)x*y%mod; 
}
int qpow(int x,int y) { 
    int tmp=1; 
    for(;y;y>>=1,x=(ll)x*x%mod) 
        if(y&1) {   
            tmp=(ll)tmp*x%mod; 
        } 
    return tmp; 
} 
int get_inv(int x) { 
    return qpow(x,mod-2); 
}  
void NTT(int *a,int len,int op) { 
    for(int i=0,k=0;i<len;++i) { 
        if(i>k) { 
            swap(a[i],a[k]); 
        } 
        for(int j=len>>1;(k^=j)<j;j>>=1); 
    }  
    for(int l=1;l<len;l<<=1) { 
        int wn=qpow(3,(mod-1)/(l<<1)); 
        if(op==-1) {    
            wn=get_inv(wn); 
        } 
        for(int i=0;i<len;i+=l<<1) { 
            int w=1,x,y;  
            for(int j=0;j<l;++j) {  
                x=a[i+j],y=(ll)w*a[i+j+l]%mod;  
                a[i+j]=(ll)(x+y)%mod; 
                a[i+j+l]=(ll)(x-y+mod)%mod;   
                w=(ll)w*wn%mod;  
            }
        }
    }
    if(op==-1) { 
        int iv=get_inv(len); 
        for(int i=0;i<len;++i) {    
            a[i]=(ll)a[i]*iv%mod;  
        }
    }
}
void init() {
    fac[0]=lw[0]=1; 
    lw[1]=get_inv(2);  
    for(int i=1;i<N;++i) { 
        fac[i]=(ll)fac[i-1]*i%mod; 
        lw[i]=(ll)lw[i-1]*lw[1]%mod;  
    }  
    inv[1]=1; 
    for(int i=2;i<N;++i) { 
        inv[i]=(ll)(mod-mod/i)*inv[mod%i]%mod; 
    } 
    inv[0]=1; 
    for(int i=1;i<N;++i) inv[i]=(ll)inv[i-1]*inv[i]%mod; 
}
int C(int x,int y) { 
    return (ll)fac[x]*inv[y]%mod*inv[x-y]%mod;   
}
int main() { 
    // setIO("input");      
    int n,m,D,lim;  
    scanf("%d%d%d",&D,&n,&m);     
    if(n-2*m<0) {  
        printf("0\n");  
        return 0; 
    } 
    if(n-2*m>=D) { 
        printf("%d\n",qpow(D,n)); 
        return 0; 
    }  
    init();  
    for(lim=1;lim<=(D<<1);lim<<=1);  
    for(int i=0;i<=D;++i) { 
        int d=(i&1)?mod-1:1;   
        A[i]=(ll)d*inv[i]%mod*qpow(DEC(D,2*i),n)%mod;   
        B[i]=inv[i];  
    }
    NTT(A,lim,1),NTT(B,lim,1);  
    for(int i=0;i<lim;++i) {    
        A[i]=(ll)A[i]*B[i]%mod; 
    } 
    NTT(A,lim,-1);  
    for(int i=0;i<=D;++i) { 
        f[i]=(ll)A[i]*C(D,i)%mod*fac[i]%mod*lw[i]%mod;  
    }    
    for(int i=0;i<lim;++i) {    
        A[i]=B[i]=0; 
    }
    for(int i=0;i<=D;++i) {  
        int d=(i&1)?mod-1:1;  
        A[i]=(ll)f[D-i]*fac[D-i]%mod;   
        B[i]=(ll)d*inv[i]%mod;   
    }
    NTT(A,lim,1),NTT(B,lim,1); 
    for(int i=0;i<lim;++i) { 
        A[i]=(ll)A[i]*B[i]%mod;  
    }   
    NTT(A,lim,-1);    
    int ans=0;   
    for(int i=0;i<=n-2*m;++i) { 
        ans=ADD(ans,(ll)inv[i]*A[D-i]%mod);  
    }         
    printf("%d\n",ans); 
    return 0; 
}