1. 程式人生 > 其它 >FFT&NTT

FFT&NTT

  • 快速傅立葉變換(Fast-Fourier-Transform)

已知多項式$A(x)=\sum _{i=0}^{N} a_ix^i,B(x)=\sum _{i=0}^{M} b_ix^i$求$A(x)*B(x)$.

  顯然看出可以列舉兩個多項式的係數,依次算出,時間$O(nm)$.

  太慢了!!怎麼辦?利用一個奇妙的東西:FFT

  • 多項式的點值表示法

    對於一個多項式$A(x)=\sum _{i=0}^{N} a_ix^i$,可以取$N$個不同的$x$值,求得$N$個多項式值。將其作為點,即$(x_i,A(x_i))$

  FFT的大致思路就是

  1.   將多項式化為點值形式。
  2.         將點值相乘。即算出每一個$C(x)=A(x)*B(x)$
  3.        將新的點值轉化回多項式形式。

前置芝士:向量與複數

  • 向量

  向量,即有方向的量,在平面直角座標系上可以用$(a,b)$表示。

  圖形上即為由原點指向點$(a,b)$的有向線段。

  向量的模長為$\sqrt{a^2+b^2}$

  向量的幅角為向量逆時針旋轉至與x軸正半軸重合時旋轉的角度。

  向量的加減法滿足平行四邊形法則,即$\overrightarrow{m}(x_1,y_1)\pm \overrightarrow{n}(x_2,y_2) = \overrightarrow{p}(x_1 \pm x_2,y_1 \pm y_2)$

  

  • 複數

  定義虛數單位$i$ 滿足$i^2=-1$,複數域$I$,形如$a+bi,(a,b\in \mathbb{R})$的數叫做複數。

  複數$a+bi$可以在座標系中表示為$(a,b)$的向量。

  同時複數的加減法滿足向量的加減法,模長與幅角的定義也與向量相同。

  若複數$z$的模長為$|z|$,幅角為$\theta$,根據座標系則有

$z=|z|cos \theta +i|z|sin \theta$

  複數的乘法:

$(x_1+y_1 i)*(x_2+y_2 i)=x_1 x_2+x_2 y_1 i+x_1 y_2 i - y_1 y_2 $

$=(x_1 x_2 - y_1 y_2)+(x_1 y_2+x_2 y_1)i$

  並且兩個複數相乘遵循一個規律:模長相乘,幅角相加。

  • 複數的單位根

   在座標系中做一個單位圓,將單位圓等分成$n$份的$n$個向量所對應的複數稱為$n$次單位根

  幅角最小的記為$\omega _n$,而幅角是$\omega _n$的$k$倍的單位根為$\omega _n^k$.

     

  8次單位根↑

  由於我們只需要$2^n$次單位根,所以以下單位根均為$2$的冪次單位根。

  單位根的性質:

  1.$\omega _n^{kn} =1 (k\in \mathbb{Z}) , \omega _n^k * \omega _n^j = \omega _n^{k+j}$ 

  根據複數乘法,很明顯。

  2.$\omega _n^k =cos 2\pi \frac{k}{n} +isin 2\pi \frac{k}{n}$ 

  複數的三角表示法。

  3.$\omega _{tn}^{tk} =\omega _n^k$

  因為$cos 2\pi \frac{tk}{tn}+isin 2\pi \frac{tk}{tn}=cos 2\pi \frac{k}{n}+isin 2\pi \frac{k}{n}$.

  4.$\omega _{n}^{\frac {n}{2}}=-1$

 $\omega _{n}^{\frac {n}{2}}=cos 2\pi \frac{1}{2}+isin 2\pi \frac{1}{2}$

$=cos \pi +isin \pi$

$=-1$

快速傅立葉變換O(nlogn)

   設一個多項式$A(x)$的係數為$(a_0,a_1,a_2...a_{n-1})$.

   首先在$A(x)$後面補係數為$0$的項,直到$n$為$2$的冪數,方便接下來運算。

   我們可以將所有的$\omega _n^k k \in [0,n-1]$代入求得$n$個點值,並可以做出優化。

$A(x)=a_0+a_1x+a_2x^2+...+a_{n-1}x^{n-1}$

$=(a_0+a_2x^2+...+a_{n-2}x^{n-2})+x(a_1+a_3x^2+...+a_{n-1}x^{n-1})$

      令$A_1(x)=a_0+a_2x^2+...+a_{n-2}x^{n-2},A_2(x)=a_1+a_3x^2+...+a_{n-1}x^{n-1}$.

$A(x)=A_1(x^2)+xA_2(x^2)$

      將$x=\omega _n^k (0 \leq k<\frac{n}{2})$代入上式

$A(\omega _n^k)=A_1(\omega _n^{2k})+\omega _n^kA_2(\omega _n^{2k})$

$=A_1(\omega _{\frac{n}{2}}^{k})+\omega _n^kA_2(\omega _{\frac{n}{2}}^{k})$

    同理將$x=\omega _n^{k+\frac{n}{2}} (0 \leq k<\frac{n}{2})$代入。

$A(\omega _n^{k+\frac{n}{2}})=A_1(\omega _n^{2k+n})-\omega _n^kA_2(\omega _n^{2k+n})$

$=A_1(\omega _n^{2k})-\omega _n^kA_2(\omega _n^{2k})$

$=A_1(\omega _{\frac{n}{2}}^{k})-\omega _n^kA_2(\omega _{\frac{n}{2}}^{k})$

  之後我們發現只要求出$A_1(\omega _{\frac{n}{2}}^{k})和A_2(\omega _{\frac{n}{2}}^{k})$就可以算出兩個點值。而他們可以遞迴去求,並且剛好由$n$次變為了$\frac{n}{2}$次,時間複雜度類似線段樹$O(nlogn)$.

  然後求出兩個多項式的所有點值之後將他們分別相乘,得出新多項式的$N+M+1$個點值,這一步是$O(n)$的。

快速傅立葉逆變換O(nlogn)

  接下來我們只需要把點值形式轉化為多項式形式即可。

  設多項式$A(x)(a_0,a_1,a_2...a_{n-1})$的點值表示為$(y_0,y_1,y_2...y_{n-1})$

  多項式$D(x)=\sum _{i=0}^{n-1} y_ix^i$,$D(x)$在$(\omega _n^0,\omega _n^{-1},\omega _n^{-2}...\omega _n^{-(n-1)})$的點值表示為$(c_0,c_1,c_2...c_{n-1})$

  則有

$c_k=\sum _{i=0}^{n-1} y_i(\omega _n^{-k})^{i}$

$=\sum _{i=0}^{n-1} \sum _{j=0}^{n-1} a_j\omega _n^{ij} \omega _n^{-ik}$

$=\sum _{i=0}^{n-1} \sum _{j=0}^{n-1} a_j(\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} \sum _{i=0}^{n-1} a_j(\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} a_j\sum _{i=0}^{n-1} (\omega _n^{j-k})^i$

  令$T(x)=\sum _{i=0}^{n-1} x^i$,則有

$T(\omega _n^{t})=1+\omega _n^t+(\omega _n^t)^2+...+(\omega _n^t)^{n-1}$  A式

$A*\omega _n^{t}$得:

$\omega _n^{t}T(\omega _n^{t})=\omega _n^t+(\omega _n^t)^2+...+(\omega _n^t)^{n}$ B式 

$B-A$得

$(\omega _n^{t}-1)T(\omega _n^{t})=(\omega _n^t)^{n}-1$

$(\omega _n^{t}-1)T(\omega _n^{t})=(\omega _n^n)^{t}-1=1-1=0$

    所以當$\omega _n^{t}-1!=0$時$T(\omega _n^{t})=0$

    當$\omega _n^{t}-1=0$時可以得到$\omega _n^{t}=1,t=0$.

    則$T(\omega _n^{t})=T(1)=\sum _{i=0}^{n-1} 1=n$.

    有了這個結論後我們來看這個式子:

$c_k=\sum _{j=0}^{n-1} a_j\sum _{i=0}^{n-1} (\omega _n^{j-k})^i$

$=\sum _{j=0}^{n-1} a_jT(\omega _n^{j-k})$

   當且僅當$j=k$時有值,即

$c_k=a_jn$

$a_j=\frac{a_k}{n}$

    所以我們只需要求出多項式$D(x)$在$(\omega _n^0,\omega _n^{-1},\omega _n^{-2}...\omega _n^{-(n-1)})$的點值表示即可算出$a_i$.

遞迴版:

const db pi=acos(-1);
class cplx{
public:
    db x,y;
    cplx(){x=y=0;}
    cplx(const db a,const db b){x=a,y=b;}
    friend cplx operator +(const cplx a,const cplx b){return cplx(a.x+b.x,a.y+b.y);}
    friend cplx operator -(const cplx a,const cplx b){return cplx(a.x-b.x,a.y-b.y);}
    friend cplx operator *(const cplx a,const cplx b){return cplx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);}
}a[maxn],b[maxn];
int N,M,lim=1;
void fft(int lm,cplx *a,int op){
    if(lm==1) return;
    cplx a1[lm>>1],a2[lm>>1];
    for(int i=0;i<=lm;i+=2)
        a1[i>>1]=a[i],a2[i>>1]=a[i+1];
    fft(lm>>1,a1,op);
    fft(lm>>1,a2,op);
    cplx w1=cplx(cos(2*pi/lm),op*sin(2*pi/lm)),wk=cplx(1,0);
    for(int i=0;i<(lm>>1);i++,wk=wk*w1){
        cplx b=wk*a2[i];
        a[i]=a1[i]+b;
        a[i+(lm>>1)]=a1[i]-b;
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%lf",&a[i].x);
    for(int i=0;i<=M;i++) scanf("%lf",&b[i].x);
    while(lim<=N+M) lim<<=1;
    fft(lim,a,1);
    fft(lim,b,1);
    for(int i=0;i<=lim;i++) a[i]=a[i]*b[i];
    fft(lim,a,-1);
    for(int i=0;i<=N+M;i++) prt(a[i].x/lim);
    return 0;
}

 但是我們發現,這種寫法需要很多次複製陣列,既耗記憶體也耗空間。

迭代優化:

  我們寫出$n=8$時的遞迴詳細:

  我們發現一個神奇的性質:遞迴到最底層時實際的值為原下標的二進位制翻轉!!(具體證明見文末)

    於是我們沒有必要再進行遞迴,只需要將陣列調換至最底層的狀態然後一層一層往回的迭代即可!

  

const db pi=acos(-1);
class cplx{
public:
    db x,y;
    cplx(){x=y=0;}
    cplx(const db a,const db b){x=a,y=b;}
    friend cplx operator +(const cplx a,const cplx b){return cplx(a.x+b.x,a.y+b.y);}
    friend cplx operator -(const cplx a,const cplx b){return cplx(a.x-b.x,a.y-b.y);}
    friend cplx operator *(const cplx a,const cplx b){return cplx(a.x*b.x-a.y*b.y,a.y*b.x+a.x*b.y);}
}a[maxn],b[maxn];
int N,M,lim=1,tr[maxn],l=0;
void fft(cplx *a,int op){
    for(int i=0;i<lim;i++) if(i<tr[i])swap(a[i],a[tr[i]]);
    for(int m=1;m<lim;m<<=1){
        cplx w1(cos(pi/m),op*sin(pi/m));
        int len=m<<1;
        for(int i=0;i<lim;i+=len){
            cplx wk(1,0);
            for(int k=0;k<m;k++,wk=wk*w1){
                cplx a1=a[i+k],a2=wk*a[i+m+k];
                a[i+k]=a1+a2;
                a[i+m+k]=a1-a2;
            }
        }
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%lf",&a[i].x);
    for(int i=0;i<=M;i++) scanf("%lf",&b[i].x);
    while(lim<=N+M) lim<<=1,++l;
    for(int i=1;i<lim;i++){
        tr[i]=(tr[i>>1]>>1)|((i&1)?(1<<(l-1)):0);
    }
    fft(a,1);
    fft(b,1);
    for(int i=0;i<=lim;i++) a[i]=a[i]*b[i];
    fft(a,-1);
    for(int i=0;i<=N+M;i++) prt(a[i].x/lim);
    return 0;
}

總時間複雜度為O(nlogn).

  • 快速數論變換(Number-Theoretic-Transform)

  我們發現,FFT中因為要用到三角函式以及浮點數的運算,精度得不到保障,並且複數的常數較大,我們可以進行優化:

  引入概念:

  • 原根

  設m是正整數,a是整數,若a模m的階等於φ(m),則稱a為模m的一個原根。(其中φ(m)表示m的尤拉函式)

  先不用管原根的定義,扔出一個結論(設$g$為$P$的原根):

$\omega _n \equiv g^{\frac{P-1}{n}} (mod P)$

  原根滿足這樣的性質:

$g^i != g^j (mod P,i!=j)$

  並且根據費馬小定理:

$\omega _n^n \equiv g^{P-1} =1 (mod P)$

  所以我們知道原根的性質與單位根類似,可以用$g^{\frac {P-1}{n}}$來代替$\omega _n$.

  如何求質數$P$的原根?

  首先需要知道滿足$a^n \equiv 1 (mod P)$的最小$n$值一定滿足$n|P-1$.

  質因數分解$P-1=\prod p_i^{a_i}$

  那麼如果有$m|P-1,n|P-1,m|n,a^m \equiv 1(mod P)$,則有$a^n \equiv 1(mod P)$

  所以要驗證一個數$t$是不是原根,要列舉每一個$p_i$,均滿足$t^{\frac{P-1}{p_i}}!=1(mod P)$成立,則$t$是原根。

  $P$一般取$998244353$,他的原根是$3$.

const int maxn=(1<<21)+5,mod=998244353,g=3;
int qp(int x,int y){
    long long ans=1;
    while(y){
        if(y&1) ans=ans*x%mod;
        x=((long long)x*x)%mod;
        y>>=1;
    }
    return (int)ans;
}
const int ginv=qp(g,mod-2);
int a[maxn],b[maxn];
int N,M,lim=1,tr[maxn],l=0;
void ntt(int *a,int op){
    for(int i=0;i<lim;i++) if(i<tr[i])swap(a[i],a[tr[i]]);
    for(int m=1;m<lim;m<<=1){
        int len=m<<1;
        int g1=qp(op==1?g:ginv,(mod-1)/len);
        for(int i=0;i<lim;i+=len){
            int gk=1;
            for(int k=0;k<m;k++,gk=(long long)gk*g1%mod){
                int a1=a[i+k],a2=(long long)gk*a[i+m+k]%mod;
                a[i+k]=(a1+a2)%mod;
                a[i+m+k]=(a1-a2+mod)%mod;
            }
        }
    }
}
int MAIN(){
    cin>>N>>M;
    for(int i=0;i<=N;i++) scanf("%d",&a[i]);
    for(int i=0;i<=M;i++) scanf("%d",&b[i]);
    while(lim<=N+M) lim<<=1,++l;
    for(int i=1;i<lim;i++){
        tr[i]=(tr[i>>1]>>1)|((i&1)?(1<<(l-1)):0);
    }
    ntt(a,1);
    ntt(b,1);
    for(int i=0;i<=lim;i++) a[i]=((long long)a[i]*b[i])%mod;
    ntt(a,-1);
    int ny=qp(lim,mod-2);
    for(int i=0;i<=N+M;i++) printf("%lld ",(long long)a[i]*ny%mod);
    return 0;
}

保證了無精度誤差,並且跑的飛快,大概是FFT速度2倍。

關於二進位制翻轉的證明:

  顯然,在前$i$層對應著原下標的前$i$位,向左即為$0$,向右即為$1$.

  而前$i$層對應實際係數下標的後$i$位,向左即為$0$,向右即為$1$,因為選出奇數代表選擇$1$,偶數代表$0$

  所以對於任意一層,原下標的前$i$位均相等,實際係數下標的後$i$位均相等,且兩者有著翻轉關係。

  在最底層即為原下標是實際下標的翻轉。