1. 程式人生 > 其它 >多項式/卷積相關(板子

多項式/卷積相關(板子

完形填空:兄弟會()你,()會()你,(),(),只有()不會,不會就是()。

又將是一篇除了板子啥都沒有的博...

雖然可能在板子前胡一堆東西,但其實都是廢話。。

卷積與DFT

形如這樣的式子被成為卷積:

\[c_k=\sum_{i\circ j=k}a_ib_j \]

其中 \(\circ\) 為任意一種運算。當它為 \(+\) 時該卷積即為多項式卷積。不難發現 \(\left\{c\right\}\) 即為係數分別為 \(\left\{a\right\}\)\(\left\{b\right\}\) 的多項式相乘後得到多項式的係數。

直接計算卷積是 \(O(n^2)\) 的。考慮對三個序列進行變換,使變換後得到的 \(\hat a\)\(\hat b\)\(\hat c\)

滿足如下式子

\[\hat{c_i}=\hat{a_i}\hat{b_i} \]

也就是將卷積轉換為了對位相乘,這樣就可以 \(O(n)\) 完成卷積。最後再通過某種方式將變換映射回去即可。

一種方便的方法是離散傅立葉變換DFT。它將 \(n-1\) 次多項式 \(f\) 的係數向量 \(\vec a=(a_0,a_1,\ldots,a_{n-2},a_{n-1})\) 轉化為點值向量 \(\vec y=(f(\omega_n^0),f(\omega_n^1),\ldots,f(\omega_n^{n-2}),f(\omega_n^{n-1}))\) 。其中 \(\omega_n\)\(n\)

次單位根,即複數域內 \(x^n=1\) 的解。一般選用幅角最小的那個,即 \(\cos(\frac{2\pi}{n})+\sin(\frac{2\pi}{n})i\)

選擇單位根作為點值的橫座標是因為它很多優秀的性質。這裡不表不會。可以參考周指導の指導。(其實這篇博大部分東西都是這裡學的

FFT

利用複數域內單位根的性質加速離散傅立葉變換過程的演算法。

具體性質不會。可以跟周指導學習一波。

因為不大能理解,一開始背板在FFT部分屬實有些煎熬。這裡加一點沒有邏輯的註釋。

void fft(complex *a,int n){
    for(int i=0;i<n;i++)
        if(r[i]>i) swap(a[i],a[r[i]]); //提出要迭代的序列,先處理r[i]
    for(int t=n>>1,d=1;d<n;t>>=1,d<<=1) //d為迭代區間長度的一半,t為計算單位根次數的係數
        for(int i=0;i<n;i+=(d<<1)) //i為當前迭代區間的左端點
            for(int j=0;j<d;j++){ //j為當前處理的位置
                complex tmp=w[t*j]*a[i+j+d];
                a[i+j+d]=a[i+j]-tmp;
                a[i+j]=a[i+j]+tmp;
            }
}

不能幫助理解,但能幫助記憶

洛谷P3803[模板]多項式乘法
#include<bits/stdc++.h>
using namespace std;

namespace IO{
    typedef long long LL;
    typedef double DB;
    int read(){
        int x=0,f=0; char ch=getchar();
        while(ch>'9'||ch<'0'){ f|=(ch=='-'); ch=getchar(); }
        while(ch>='0'&&ch<='9'){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); }
        return f?-x:x;
    } char output[50];
    void write(int x,char sp){
        int len=0;
        if(x<0) putchar('-'), x=-x;
        do{ output[len++]=x%10+'0'; x/=10; }while(x);
        for(int i=len-1;~i;i--) putchar(output[i]); putchar(sp);
    }
    void ckmax(int& x,int y){ x=x>y?x:y; }
    void ckmin(int& x,int y){ x=x<y?x:y; }
} using namespace IO;

const int NN=1<<22;
int dega,degb;

namespace Poly_Calculation{
    const DB PI=acos(-1.0);
    struct complex{
        DB r,i;
        complex(){ r=i=0; }
        complex(DB x,DB y){ r=x; i=y; }
        complex operator+(const complex& t)const{ return complex(r+t.r,i+t.i); }
        complex operator-(const complex& t)const{ return complex(r-t.r,i-t.i); }
        complex operator*(const complex& t)const{ return complex(r*t.r-i*t.i,r*t.i+i*t.r); }
    }a[NN],b[NN],w[NN];
    int l,len,r[NN];
    void prework(){
        for(len=1;len<=dega+degb;len<<=1) ++l;
        for(int i=0;i<len;i++){
            r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
            w[i]=complex(cos(2.0*i*PI/len),sin(2.0*i*PI/len));
        }
    }
    void fft(complex *a,int n){
        for(int i=0;i<n;i++)
            if(r[i]>i) swap(a[i],a[r[i]]);
        for(int t=n>>1,d=1;d<n;t>>=1,d<<=1)
            for(int i=0;i<n;i+=(d<<1))
                for(int j=0;j<d;j++){
                    complex tmp=w[t*j]*a[i+j+d];
                    a[i+j+d]=a[i+j]-tmp;
                    a[i+j]=a[i+j]+tmp;
                }
    }
    void poly_mul(complex *a,complex *b){
        fft(a,len); fft(b,len);
        for(int i=0;i<len;i++)
            a[i]=a[i]*b[i], w[i].i=-w[i].i;
        fft(a,len);
        for(int i=0;i<len;i++) a[i].r=a[i].r/len;
    }
} using namespace Poly_Calculation;

signed main(){
    dega=read()+1; degb=read()+1;
    for(int i=0;i<dega;i++) a[i].r=read();
    for(int i=0;i<degb;i++) b[i].r=read();
    prework(); poly_mul(a,b);
    --dega; --degb;
    for(int i=0;i<=dega+degb;i++)
        write(int(a[i].r+0.5),' ');
    return puts(""),0;
}

NTT

雖然FFT已經可以完成快速多項式卷積的工作,但它也存在一些弊端。如精度問題,以及涉及取模時無法方便運算等。

這時就需要使用NTT。大體來講,NTT與FFT的原理是完全一致的,只不過使用模意義下的原根(的整數次冪)代替了複數域內的單位根,避免了浮點數計算。

經過一些推導,令模數意義下原根為 \(g\) ,那麼原來的 \(w_n\) 就可以替換為 \(g^{\frac{mod-1}{n}}\)

因為多項式長度都取 \(2\) 的整數次冪,所以要求 \(g\) 的指數為整數,對模數有一些要求。

兩個常見的模數: \(998244353\)\(1004535809\) ,原根都為 \(3\)

板子方面幾乎跟FFT一模一樣啊。。

洛谷P3803[模板]多項式乘法
#include<bits/stdc++.h>
using namespace std;

namespace IO{
    typedef long long LL;
    typedef double DB;
    LL read(){
        LL x=0,f=0; char ch=getchar();
        while(ch>'9'||ch<'0'){ f|=(ch=='-'); ch=getchar(); }
        while(ch>='0'&&ch<='9'){ x=(x<<1)+(x<<3)+(ch^48); ch=getchar(); }
        return f?-x:x;
    } char output[50];
    void write(LL x,char sp){
        int len=0;
        if(x<0) putchar('-'), x=-x;
        do{ output[len++]=x%10+'0'; x/=10; }while(x);
        for(int i=len-1;~i;i--) putchar(output[i]); putchar(sp);
    }
    void ckmax(int& x,int y){ x=x>y?x:y; }
    void ckmin(int& x,int y){ x=x<y?x:y; }
} using namespace IO;

const int NN=1<<22,mod=998244353;
int dega,degb;
LL qpow(LL a,LL b=mod-2,LL res=1){
    for(;b;b>>=1,a=a*a%mod)
        if(b&1) res=res*a%mod;
    return res;
}

namespace Poly_Calculation{
    LL l,len,inv,g[NN],r[NN],a[NN],b[NN];
    void prework(){
        for(len=1;len<=dega+degb;len<<=1) ++l;
        for(int i=0;i<len;i++)
            r[i]=(r[i>>1]>>1)|((i&1)<<l-1);
        g[0]=1; g[1]=qpow(3,(mod-1)/len);
        for(int i=2;i<len;i++) g[i]=g[i-1]*g[1]%mod;
    }
    void ntt(LL *a,int n){
        for(int i=0;i<n;i++)
            if(r[i]>i) swap(a[r[i]],a[i]);
        for(int t=n>>1,d=1;d<n;d<<=1,t>>=1)
            for(int i=0;i<n;i+=(d<<1))
                for(int j=0;j<d;j++){
                    LL tmp=g[t*j]*a[i+j+d]%mod;
                    a[i+j+d]=(a[i+j]-tmp+mod)%mod;
                    a[i+j]=(a[i+j]+tmp)%mod;
                }
    }
    void poly_mul(LL *a,LL *b){
        ntt(a,len); ntt(b,len);
        for(int i=0;i<len;i++)
            a[i]=a[i]*b[i]%mod, g[i]=qpow(g[i]);
        ntt(a,len);
        inv=qpow(len);
        for(int i=0;i<len;i++) a[i]=a[i]*inv%mod;
    }
} using namespace Poly_Calculation;

signed main(){
    dega=read()+1; degb=read()+1;
    for(int i=0;i<dega;i++) a[i]=read();
    for(int i=0;i<degb;i++) b[i]=read();
    prework(); poly_mul(a,b);
    --dega; --degb;
    for(int i=0;i<=dega+degb;i++)
        write(a[i],' ');
    return puts(""),0;
}