【模板】多項式乘法
阿新 • • 發佈:2019-01-12
Description
給定一個\(n\)次多項式\(F(x)\),和一個\(m\)次多項式\(G(x)\)。
請求出\(F(x)\)和\(G(x)\)的卷積。
Input
第一行2個正整數\(n,m\)。
接下來一行\(n+1\)個數字,從低到高表示\(F(x)\)的係數。
接下來一行\(m+1\)個數字,從低到高表示\(G(x)\)的係數。
Output
一行\(n+m+1\)個數字,從低到高表示\(F(x)∗G(x)\)的係數。
Code
FFT
#include<cstdio> #include<algorithm> #include<cstring> #include<complex> #include<cmath> #define cp complex < double > using namespace std; const double pi=acos(-1); int lena,lenb,n,res[4000010]; cp F[4000010],G[4000010],arr[4000010],inv[4000010]; inline int read() { int ans=0,f=-1; char ch=getchar(); while (ch<'0' || ch>'9') { if (ch=='-') f=-1; ch=getchar(); } while (ch>='0' && ch<='9') { ans=ans*10+ch-'0'; ch=getchar(); } return ans; } void init() { for (int i=0;i<n;i++) { arr[i]=cp(cos(2*pi*i/n),sin(2*pi*i/n)); inv[i]=conj(arr[i]); } } void FFT(cp *a,cp *arr) { int lim=0; while ((1<<lim)<n) lim++; for (int i=0;i<n;i++) { int t=0; for (int j=0;j<lim;j++) if ((i>>j) & 1) t|=1<<(lim-j-1); if (i<t) swap(a[i],a[t]); } for (int l=2;l<=n;l*=2) { int m=l/2; for (cp *buf=a;buf!=a+n;buf+=l) for (int i=0;i<m;i++) { cp t=arr[n/l*i]*buf[i+m]; buf[i+m]=buf[i]-t; buf[i]+=t; } } } int main() { lena=read();lenb=read(); lena++;lenb++; for (int i=0;i<lena;i++) F[i].real(read()); for (int i=0;i<lenb;i++) G[i].real(read()); n=1;while (n<(lena+lenb)) n<<=1; init(); FFT(F,arr);FFT(G,arr); for (int i=0;i<n;i++) F[i]*=G[i]; FFT(F,inv); for (int i=0;i<n;i++) res[i]=floor(F[i].real()/n+0.5); for (int i=0;i<lena+lenb-1;i++) printf("%d ",res[i]); return 0; }
NTT
#include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #include<complex> #define cp complex < double > using namespace std; const int Mod=998244353; const int p=3,invp=332748118; int lena,lenb,n,res[4000010]; int F[4000010],G[4000010]; inline int read() { int ans=0,f=-1; char ch=getchar(); while (ch<'0' || ch>'9') { if (ch=='-') f=-1; ch=getchar(); } while (ch>='0' && ch<='9') { ans=ans*10+ch-'0'; ch=getchar(); } return ans; } int fpow(int x,int k) { int ans=1; while (k) { if (k&1) ans=1LL*ans*x%Mod; x=1LL*x*x%Mod; k>>=1; } return ans; } void NTT(int *a,int inv) { int lim=0; while ((1<<lim)<n) lim++; for (int i=0;i<n;i++) { int t=0; for (int j=0;j<lim;j++) if ((i>>j) & 1) t|=1<<(lim-j-1); if (i<t) swap(a[i],a[t]); } for (int l=2;l<=n;l*=2) { int m=l/2,p0=fpow(inv?invp:p,(Mod-1)/l); for (int *buf=a;buf!=a+n;buf+=l) { int pn=1; for (int i=0;i<m;i++) { int t=1LL*pn*buf[i+m]%Mod; buf[i+m]=(buf[i]-t+Mod)%Mod; buf[i]=(buf[i]+t)%Mod; pn=1LL*pn*p0%Mod; } } } } int main() { lena=read(),lenb=read(); lena++;lenb++; n=1; while (n<(lena+lenb)) n<<=1; for (int i=0;i<lena;i++) F[i]=read(); for (int i=0;i<lenb;i++) G[i]=read(); NTT(F,0);NTT(G,0); for (int i=0;i<n;i++) F[i]=1LL*F[i]*G[i]%Mod; NTT(F,1); int invn=fpow(n,Mod-2); for (int i=0;i<lena+lenb-1;i++) printf("%d ",1LL*F[i]*invn%Mod); return 0; }