【模板】FFT&NTT
阿新 • 發佈:2019-01-04
Reference blogs:
https://oi.men.ci/fft-notes/
https://oi.men.ci/fft-to-ntt/
FFT:
// [Template] FFT: multiply two integer polynomials with an iterative radix-2
// complex fast Fourier transform.
//
// Input (stdin):   n m, then the n+1 coefficients of A, then the m+1
//                  coefficients of B (lowest degree first).
// Output (stdout): the n+m+1 coefficients of A*B, each followed by a space.
//
// Products are rounded to the nearest integer after the inverse transform, so
// they are exact only while accumulated floating-point error stays below 0.5.
#include <cmath>
#include <complex>
#include <cstdio>
#include <utility>
#include <vector>

using D = std::complex<double>;
const double PI = std::acos(-1.0);

// Reads one decimal integer from stdin. Non-digit characters before the first
// digit are skipped; a '-' among them makes the result negative. Stops at EOF
// instead of spinning forever.
inline void mread(int &rx) {
    int fx = 1;
    int c = std::getchar();  // int, not char, so EOF is distinguishable
    rx = 0;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') fx = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        rx = rx * 10 + (c - '0');
        c = std::getchar();
    }
    rx *= fx;
}

// Builds the root tables for a transform of length n:
// omega[k] = exp(2*pi*i*k/n) and iomega[k] = conj(omega[k]) = exp(-2*pi*i*k/n).
void init(int n, std::vector<D> &omega, std::vector<D> &iomega) {
    omega.resize(n);
    iomega.resize(n);
    for (int k = 0; k < n; ++k) {
        const double ang = 2 * PI / n * k;
        omega[k] = D(std::cos(ang), std::sin(ang));
        iomega[k] = std::conj(omega[k]);
    }
}

// In-place FFT of w, whose size must be n = 1 << lg. flag == 1 selects the
// forward roots (omega), any other value the inverse roots (iomega). The
// inverse result is NOT divided by n here; the caller does that.
void FFT(std::vector<D> &w, int lg, int flag,
         const std::vector<D> &omega, const std::vector<D> &iomega) {
    const int n = 1 << lg;
    // Bit-reversal permutation so the butterflies below can work in place.
    for (int i = 0; i < n; ++i) {
        int t = 0;
        for (int j = 0; j < lg; ++j)
            if (i & (1 << j)) t |= 1 << (lg - j - 1);
        if (i > t) std::swap(w[i], w[t]);
    }
    const std::vector<D> &root = (flag == 1) ? omega : iomega;
    for (int len = 2; len <= n; len <<= 1) {
        const int m = len >> 1;
        const int step = n / len;  // root[step * j] == (len-th root of unity)^j
        for (int i = 0; i < n; i += len) {
            for (int j = 0; j < m; ++j) {
                const D bw = root[step * j] * w[i + m + j];
                w[i + m + j] = w[i + j] - bw;
                w[i + j] += bw;
            }
        }
    }
}

int main() {
    int n1, n2;
    mread(n1);
    mread(n2);
    ++n1;  // degree -> number of coefficients
    ++n2;
    const int total = n1 + n2 - 1;  // coefficient count of A*B

    // Smallest power of two >= total, so the cyclic convolution equals the
    // linear one. Buffers are sized from this, not from a fixed cap.
    int lg = 0;
    while ((1 << lg) < total) ++lg;
    const int n = 1 << lg;

    std::vector<D> c1(n), c2(n);  // zero-padded to length n
    for (int i = 0; i < n1; ++i) {
        int x;
        mread(x);
        c1[i] = D(x, 0);
    }
    for (int i = 0; i < n2; ++i) {
        int x;
        mread(x);
        c2[i] = D(x, 0);
    }

    std::vector<D> omega, iomega;
    init(n, omega, iomega);
    FFT(c1, lg, 1, omega, iomega);
    FFT(c2, lg, 1, omega, iomega);
    for (int i = 0; i < n; ++i) c1[i] *= c2[i];
    FFT(c1, lg, -1, omega, iomega);

    for (int i = 0; i < total; ++i) {
        // Apply the 1/n the inverse transform omitted, then round half up.
        // long long so large coefficients do not overflow int.
        const long long r =
            static_cast<long long>(std::floor(c1[i].real() / n + 0.5));
        std::printf("%lld ", r);
    }
    return 0;
}
NTT:
// [Template] NTT: multiply two integer polynomials with the number-theoretic
// transform modulo 998244353 = 119 * 2^23 + 1, whose primitive root is 3.
//
// Input (stdin):   n m, then the n+1 coefficients of A, then the m+1
//                  coefficients of B (lowest degree first).
// Output (stdout): the n+m+1 coefficients of A*B reduced into [0, Mod), each
//                  followed by a space.
//
// Only 2^23 divides Mod-1, so the transform length (the smallest power of two
// >= n+m+1) must not exceed 2^23; larger inputs are rejected.
#include <cstdio>
#include <utility>
#include <vector>

typedef long long LL;
const LL Mod = 998244353, G = 3, Gi = 332748118;  // Gi * G == 1 (mod Mod)
const int kMaxLog = 23;                             // 2^23 divides Mod - 1

// Reads one decimal integer from stdin. Non-digit characters before the first
// digit are skipped; a '-' among them makes the result negative. Stops at EOF
// instead of spinning forever.
inline void mread(LL &rx) {
    LL fx = 1;
    int c = std::getchar();  // int, not char, so EOF is distinguishable
    rx = 0;
    while (c != EOF && (c < '0' || c > '9')) {
        if (c == '-') fx = -1;
        c = std::getchar();
    }
    while (c >= '0' && c <= '9') {
        rx = rx * 10 + (c - '0');
        c = std::getchar();
    }
    rx *= fx;
}

// Modular arithmetic; operands are expected to lie in [0, Mod), so the
// product fits in 64 bits.
inline LL mmul(LL a, LL b) { return (a * b) % Mod; }
inline LL madd(LL a, LL b) { return (a + b) % Mod; }
inline LL msub(LL a, LL b) { return ((a - b) % Mod + Mod) % Mod; }

// Fast modular exponentiation: returns x^bs mod Mod for bs >= 0.
inline LL mquery(LL x, LL bs) {
    LL rans = 1;
    x %= Mod;
    while (bs > 0) {
        if (bs & 1LL) rans = mmul(rans, x);
        x = mmul(x, x);
        bs >>= 1;
    }
    return rans;
}

// In-place NTT of a, whose size must be n = 1 << lg with lg <= kMaxLog.
// flag == 1 uses powers of G (forward); any other value uses powers of Gi
// (inverse). The inverse result is NOT multiplied by n^{-1} here.
void NTT(std::vector<LL> &a, int lg, int flag) {
    const int n = 1 << lg;
    // Bit-reversal permutation so the butterflies below can work in place.
    for (int i = 0; i < n; ++i) {
        int t = 0;
        for (int j = 0; j < lg; ++j)
            if (i & (1 << j)) t |= 1 << (lg - j - 1);
        if (i < t) std::swap(a[i], a[t]);
    }
    for (int len = 2; len <= n; len <<= 1) {
        const int m = len >> 1;
        // Primitive len-th root of unity mod Mod (or its inverse).
        const LL bw = mquery(flag == 1 ? G : Gi, (Mod - 1) / len);
        for (int i = 0; i < n; i += len) {
            LL w = 1;
            for (int j = 0; j < m; ++j, w = mmul(w, bw)) {
                const LL der = mmul(w, a[i + m + j]);
                a[i + m + j] = msub(a[i + j], der);
                a[i + j] = madd(a[i + j], der);
            }
        }
    }
}

int main() {
    LL n1, n2;
    mread(n1);
    mread(n2);
    ++n1;  // degree -> number of coefficients
    ++n2;
    const LL total = n1 + n2 - 1;  // coefficient count of A*B

    // Smallest power of two >= total, so the cyclic convolution equals the
    // linear one. Buffers are sized from this, not from a fixed cap.
    int lg = 0;
    while ((1LL << lg) < total) ++lg;
    if (lg > kMaxLog) {
        std::fputs("NTT length exceeds 2^23 for modulus 998244353\n", stderr);
        return 1;
    }
    const int n = 1 << lg;

    std::vector<LL> c1(n, 0), c2(n, 0);  // zero-padded to length n
    for (LL i = 0; i < n1; ++i) {
        LL x;
        mread(x);
        c1[i] = (x % Mod + Mod) % Mod;  // normalize (also handles negatives)
    }
    for (LL i = 0; i < n2; ++i) {
        LL x;
        mread(x);
        c2[i] = (x % Mod + Mod) % Mod;
    }

    NTT(c1, lg, 1);
    NTT(c2, lg, 1);
    for (int i = 0; i < n; ++i) c1[i] = mmul(c1[i], c2[i]);
    NTT(c1, lg, -1);

    const LL ny = mquery(n, Mod - 2);  // n^{-1} mod Mod (Fermat)
    for (LL i = 0; i < total; ++i) std::printf("%lld ", mmul(c1[i], ny));
    return 0;
}