1. 程式人生 > >【模板】多項式乘法

【模板】多項式乘法

Description

給定一個\(n\)次多項式\(F(x)\),和一個\(m\)次多項式\(G(x)\)

請求出\(F(x)\)\(G(x)\)的卷積。

Input

第一行2個正整數\(n,m\)

接下來一行\(n+1\)個數字,從低到高表示\(F(x)\)的係數。

接下來一行\(m+1\)個數字,從低到高表示\(G(x)\)的係數。

Output

一行\(n+m+1\)個數字,從低到高表示\(F(x)∗G(x)\)的係數。

Code

FFT

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<complex>
#include<cmath>
#define cp complex < double >
using namespace std;
const double pi=acos(-1);
int lena,lenb,n,res[4000010];
cp F[4000010],G[4000010],arr[4000010],inv[4000010];
inline int read()
{
    int ans=0,f=-1;
    char ch=getchar();
    while (ch<'0' || ch>'9')
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0' && ch<='9')
    {
        ans=ans*10+ch-'0';
        ch=getchar();
    }
    return ans;
}
void init()
{
    for (int i=0;i<n;i++)
    {
        arr[i]=cp(cos(2*pi*i/n),sin(2*pi*i/n));
        inv[i]=conj(arr[i]);
    }
}
void FFT(cp *a,cp *arr)
{
    int lim=0;
    while ((1<<lim)<n) lim++;
    for (int i=0;i<n;i++)
    {
        int t=0;
        for (int j=0;j<lim;j++)
            if ((i>>j) & 1) t|=1<<(lim-j-1);
        if (i<t) swap(a[i],a[t]);
    }
    for (int l=2;l<=n;l*=2)
    {
        int m=l/2;
        for (cp *buf=a;buf!=a+n;buf+=l)
            for (int i=0;i<m;i++)
            {
                cp t=arr[n/l*i]*buf[i+m];
                buf[i+m]=buf[i]-t;
                buf[i]+=t;
            }
    }
}
int main()
{
    lena=read();lenb=read();
    lena++;lenb++;
    for (int i=0;i<lena;i++) F[i].real(read());
    for (int i=0;i<lenb;i++) G[i].real(read());
    n=1;while (n<(lena+lenb)) n<<=1;
    init();
    FFT(F,arr);FFT(G,arr);
    for (int i=0;i<n;i++) F[i]*=G[i];
    FFT(F,inv);
    for (int i=0;i<n;i++)
        res[i]=floor(F[i].real()/n+0.5);
    for (int i=0;i<lena+lenb-1;i++)
        printf("%d ",res[i]);
    return 0;
}

NTT

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#include<complex>
#define cp complex < double >
using namespace std;
const int Mod=998244353;
const int p=3,invp=332748118;
int lena,lenb,n,res[4000010];
int F[4000010],G[4000010];
inline int read()
{
    int ans=0,f=-1;
    char ch=getchar();
    while (ch<'0' || ch>'9')
    {
        if (ch=='-') f=-1;
        ch=getchar();
    }
    while (ch>='0' && ch<='9')
    {
        ans=ans*10+ch-'0';
        ch=getchar();
    }
    return ans;
}
int fpow(int x,int k)
{
    int ans=1;
    while (k)
    {
        if (k&1) ans=1LL*ans*x%Mod;
        x=1LL*x*x%Mod;
        k>>=1;
    }
    return ans;
}
void NTT(int *a,int inv)
{
    int lim=0;
    while ((1<<lim)<n) lim++;
    for (int i=0;i<n;i++)
    {
        int t=0;
        for (int j=0;j<lim;j++)
            if ((i>>j) & 1) t|=1<<(lim-j-1);
        if (i<t) swap(a[i],a[t]);
    }
    for (int l=2;l<=n;l*=2)
    {
        int m=l/2,p0=fpow(inv?invp:p,(Mod-1)/l);
        for (int *buf=a;buf!=a+n;buf+=l)
        {
            int pn=1;
            for (int i=0;i<m;i++)
            {
                int t=1LL*pn*buf[i+m]%Mod;
                buf[i+m]=(buf[i]-t+Mod)%Mod;
                buf[i]=(buf[i]+t)%Mod;
                pn=1LL*pn*p0%Mod;
            }
        }
    }
}
int main()
{
    lena=read(),lenb=read();
    lena++;lenb++;
    n=1;
    while (n<(lena+lenb)) n<<=1;    
    for (int i=0;i<lena;i++) F[i]=read();
    for (int i=0;i<lenb;i++) G[i]=read();
    NTT(F,0);NTT(G,0);
    for (int i=0;i<n;i++) F[i]=1LL*F[i]*G[i]%Mod;
    NTT(F,1);
    int invn=fpow(n,Mod-2);
    for (int i=0;i<lena+lenb-1;i++)
        printf("%d ",1LL*F[i]*invn%Mod);
    return 0;
}