FFT NTT 模板
阿新 • 發佈:2019-01-12
NTT:
// NTT (number-theoretic transform) polynomial multiplication modulo 998244353.
// Input : n m, then the n+1 coefficients of A, then the m+1 coefficients of B.
// Output: the n+m+1 coefficients of A*B mod 998244353, space-separated.
// Capacity: arrays hold 2*N entries, so the transform length lim must not exceed 2*N.
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 2000050
#define ll long long
#define MOD 998244353

// Reads one integer from stdin into x; a '-' before the digits makes it negative.
// Stops at EOF instead of spinning forever (a missing number reads as 0).
template<typename T> inline void read(T&x)
{
    T f=1,c=0;
    int ch=getchar(); // int, not char, so EOF can be told apart from a byte value
    while((ch<'0'||ch>'9')&&ch!=EOF)
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        c=10*c+ch-'0';
        ch=getchar();
    }
    x=f*c;
}

// Returns x^y mod MOD by binary exponentiation (y >= 0).
ll fastpow(ll x,int y)
{
    ll ret=1;
    x%=MOD; // keep every intermediate product below 2^63
    while(y)
    {
        if(y&1)ret=ret*x%MOD;
        x=x*x%MOD;
        y>>=1;
    }
    return ret;
}

int n,m,to[2*N],lim=1,l;

// In-place iterative NTT of a[0..len-1]. len must equal the global lim, because the
// bit-reversal permutation is read from to[], which main builds for lim.
// k==1 : forward transform using powers of 3.
// k==-1: inverse transform WITHOUT the 1/len factor (the caller multiplies by
//        len^{-1} afterwards).
void ntt(ll *a,int len,int k)
{
    for(int i=0;i<len;i++)
        if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        // (2i)-th root of unity: 3 is a primitive root modulo 998244353
        ll w0=fastpow(3,(MOD-1)/(i<<1));
        for(int j=0;j<len;j+=(i<<1))
        {
            ll w=1;
            for(int o=0;o<i;o++,w=w*w0%MOD)
            {
                ll w1=a[j+o],w2=a[j+o+i]*w%MOD;
                a[j+o]=(w1+w2)%MOD;
                a[j+o+i]=((w1-w2)%MOD+MOD)%MOD;
            }
        }
    }
    // Applying the forward transform to a spectrum yields len*x[(-t) mod len], so the
    // inverse is obtained by reversing a[1..len-1] of THIS array (index 0 and len/2
    // map to themselves).
    if(k==-1)
        for(int i=1;i<(len>>1);i++)swap(a[i],a[len-i]);
}

ll a[2*N],b[2*N],c[2*N];

int main()
{
    read(n),read(m);
    // Reduce inputs into [0, MOD) so negative coefficients cannot leak negative
    // residues through the butterflies into the output.
    for(int i=0;i<=n;i++)
    {
        read(a[i]);
        a[i]=(a[i]%MOD+MOD)%MOD;
    }
    for(int i=0;i<=m;i++)
    {
        read(b[i]);
        b[i]=(b[i]%MOD+MOD)%MOD;
    }
    // The product has n+m+1 coefficients, so the length must be strictly greater than
    // n+m; otherwise the cyclic convolution wraps around (e.g. n=m=1 needs 4, not 2).
    while(lim<=n+m)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    ntt(a,lim,1),ntt(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i]%MOD;
    ntt(c,lim,-1);
    ll inv=fastpow(lim,MOD-2); // lim^{-1} mod MOD (Fermat's little theorem, MOD prime)
    for(int i=0;i<lim;i++)c[i]=c[i]*inv%MOD;
    for(int i=0;i<=n+m;i++)printf("%lld ",c[i]);
    puts("");
    return 0;
}
FFT:
// FFT (complex floating-point) polynomial multiplication with integer coefficients.
// Input : n m, then the n+1 coefficients of A, then the m+1 coefficients of B.
// Output: the n+m+1 coefficients of A*B, each rounded to the nearest integer.
// Capacity: arrays hold 2*N entries, so the transform length lim must not exceed 2*N.
#include<cmath>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
#define N 2000050
#define ll long long
const double Pi=acos(-1.0);

// Reads one number from stdin into x; a '-' before the digits makes it negative.
// Stops at EOF instead of spinning forever (a missing number reads as 0).
template<typename T> inline void read(T&x)
{
    T f=1,c=0;
    int ch=getchar(); // int, not char, so EOF can be told apart from a byte value
    while((ch<'0'||ch>'9')&&ch!=EOF)
    {
        if(ch=='-')f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        c=10*c+ch-'0';
        ch=getchar();
    }
    x=f*c;
}

// Minimal complex number: x is the real part, y the imaginary part.
struct cp
{
    double x,y;
    cp():x(0),y(0){}
    cp(double x,double y):x(x),y(y){}
};
// Operands are taken by const reference so temporaries (e.g. a*b*c) are accepted.
cp operator + (const cp &a,const cp &b){return cp(a.x+b.x,a.y+b.y);}
cp operator - (const cp &a,const cp &b){return cp(a.x-b.x,a.y-b.y);}
cp operator * (const cp &a,const cp &b){return cp(a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x);}

int n,m,to[2*N],lim=1,l;

// In-place iterative FFT of a[0..len-1]. len must equal the global lim, because the
// bit-reversal permutation is read from to[], which main builds for lim.
// k==1 : forward transform; k==-1: inverse transform WITHOUT the 1/len factor
// (the caller divides by len when reading the result).
void fft(cp *a,int len,int k)
{
    for(int i=0;i<len;i++)
        if(i<to[i])swap(a[i],a[to[i]]);
    for(int i=1;i<len;i<<=1)
    {
        // (2i)-th root of unity e^{k*i*pi/i}; k flips the angle's sign for the inverse
        cp w0(cos(Pi/i),k*sin(Pi/i));
        for(int j=0;j<len;j+=(i<<1))
        {
            cp w(1,0);
            // twiddles are accumulated by repeated multiplication, so their rounding
            // error grows with i
            for(int o=0;o<i;o++,w=w*w0)
            {
                cp w1=a[j+o],w2=a[j+o+i]*w;
                a[j+o]=w1+w2;
                a[j+o+i]=w1-w2;
            }
        }
    }
}

cp a[2*N],b[2*N],c[2*N];

int main()
{
    read(n),read(m);
    for(int i=0;i<=n;i++)read(a[i].x);
    for(int i=0;i<=m;i++)read(b[i].x);
    // The product has n+m+1 coefficients, so the length must be strictly greater than
    // n+m; otherwise the cyclic convolution wraps around (e.g. n=m=1 needs 4, not 2).
    while(lim<=n+m)lim<<=1,l++;
    for(int i=1;i<lim;i++)to[i]=((to[i>>1]>>1)|((i&1)<<(l-1)));
    fft(a,lim,1),fft(b,lim,1);
    for(int i=0;i<lim;i++)c[i]=a[i]*b[i];
    fft(c,lim,-1);
    // llround rounds to nearest (half away from zero), so negative coefficients are
    // handled too; (ll)(v+0.5) truncates toward zero and would turn -3.0000001 into -2.
    for(int i=0;i<=n+m;i++)printf("%lld ",llround(c[i].x/lim));
    puts("");
    return 0;
}