1. 程式人生 > 實用技巧 >【數論】快速傅立葉變換 - 多項式相乘

【數論】快速傅立葉變換 - 多項式相乘

快速傅立葉變換

快速傅立葉變換(fast Fourier transform),即利用計算機計算離散傅立葉變換(DFT)的高效、快速計算方法的統稱,簡稱FFT。快速傅立葉變換是1965年由J.W.庫利和T.W.圖基提出的。採用這種演算法能使計算機計算離散傅立葉變換所需要的乘法次數大為減少。

—— 百度百科《快速傅立葉變換》



前置知識(主要部分)

係數表示法

對於一個多項式\(A(x)\),其係數表示法可表示為

\(A(x)=\sum_{i=0}^n{a_i*x^i}={a_0}+{a_1*x}+{a_2*x^2}+...+{a_n*x^n}\)

在係數表示法下,兩個多項式相乘的乘法規則為其中一個多項式的每一項與另一個多項式每一項相乘後相加

故該種方式求多項式乘法的時間複雜度為\(O(n^2)\)

點值表示法

將互不相同的 \(x_i\) 代入多項式中會得到互不相同的 \(y_i\)

獲得的這些點也就是多項式函式影象上的點座標 \((x_i,y_i)\)

可以證明,多項式可以由這些點唯一確定

獲得這些點後,只需要在兩個多項式對應的點值位置進行一次數值乘法即可獲得相乘後得到的多項式對應位置的值

\(y_i\) 的值滿足 \(y_i=\sum_{j=0}^na_j*x_i^j\)

若僅至此,即使是點值表示法計算多項式乘法時間複雜度也是\(O(n^2)\)的,但在這部分可以進行優化

複數

定義虛數單位\(i^2=-1\),令\(a\)\(b\)

為兩實數,形如\(a+bi\)的數即為複數



點值表示法選點的要求

FFT 要求\(n\)必須為\(2\)的正整數次冪

首先對於選點,點值表示法選點並不是隨意的,而是選擇複平面上具有特殊性質的\(n\)個點作為待代入的點

將複平面上以原點為圓心,\(1\)為半徑的圓\(n\)等分,作從原點出發終點為這些點的\(n\)個向量,設正幅度角(實數軸正方向逆時針旋轉得到的夾角)最小的向量為單位根\(w_n\),記作\(n\)次單位根

故這\(n\)個向量可表示為 \(w_n^1,w_n^2,w_n^3...w_n^n\),且\(w_n^n=w_n^0\)

這些向量值可以利用尤拉公式 \(w_n^k=\frac{2\pi}{n}cos\ k+\frac{2\pi}{n}i\ sin\ k\)

獲得



快速傅立葉變換

設多項式\(A(x)\)的係數為\((a_0,a_1,a_2...a_{n-1})\)

則多項式可表示為\(A(x)=a_0+a_1x+a_2x^2+a_3x^3+...+a_{n-1}x^{n-1}\)

將係數奇偶分開

\(A_1(x^2)=a_1+a_3x+a_5x^2+...+a_{n-1}x^{\frac{n}{2}-1}\)

\(A_2(x^2)=a_0+a_2x+a_4x^2+...+a_{n-2}x^{\frac{n}{2}-1}\)

可得\(A(x)=xA_1(x^2)+A_2(x^2)\)

\(x=w_n^k(k<\frac{n}{2})\)代入得\(A(w_n^k)=w_n^kA_1(w_n^{2k})+A_2(w_n^{2k})=w_n^kA_1(w_{\frac{n}{2}}^k)+A_2(w_{\frac{n}{2}}^k)\)

\(x=w_n^{k+\frac{n}{2}}(k<\frac{n}{2})\)代入得\(A(w_n^{k+\frac{n}{2}})=w_n^{k+\frac{n}{2}}A_1(w_n^{2({k+\frac{n}{2}})})+A_2(w_n^{2({k+\frac{n}{2}})})=A_2(w_n^{2k})-w_n^kA_1(w_n^{2k})\)

至此,可以發現兩式僅有一常數項不同,故在求第一個式子時也能直接將第二個式子求出

故此時待計算的部分縮小了一半

可知縮小後的式子仍然滿足上述性質,故該部分可由迭代來在\(O(logn)\)的時間內求出

基於此性質,快速傅立葉變換的總時間複雜度為\(O(nlogn)\)

主要部分實現方法圖示(圖是網上的,來自遠航之曲大佬):



對於快速傅立葉逆變換,本文不作闡述(推公式太麻煩),記用法即可



程式碼實現

首先對於複數,需要定義一個complex類來裝載

C++的STL中已經存在complex以供使用,但建議還是手動定義

struct cp
{
	double x,y;
	cp(double u=0,double v=0){x=u,y=v;}
	friend cp operator +(const cp &u,const cp &v){return cp(u.x+v.x,u.y+v.y);}
	friend cp operator -(const cp &u,const cp &v){return cp(u.x-v.x,u.y-v.y);}
	friend cp operator *(const cp &u,const cp &v){return cp(u.x*v.x-u.y*v.y,u.x*v.y+u.y*v.x);}
};

首先根據需求範圍\(n\)求出進行FFT操作的上界\(lim\) (需為\(2\)的正整數次冪)

int lim=1;

void initFFT(int n)
{
    int lg=0;
    while(lim<=n)
        lg++,lim<<=1;
}

※ 對於FFT的位置,還能在位置交換上通過預處理來進一步提升效率(蝴蝶迭代)

const int N=200050;
int rev[N],lim=1;
void initFFT(int n)
{
    int lg=0;
    while(lim<=n)
        lg++,lim<<=1;
    for(int i=0;i<lim;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1)); //由於i可以看做是i/2二進位制上的每一位左移一位得來,所以逆變換可以視為右移一位,同時處理下奇數
}

呼叫initFFT以獲得上界\(lim\)後,即可開始進行快速傅立葉變換(板子不是使用迭代法的)

void FFT(cp *a,int opr) //a為待進行FFT的複數陣列,opr為1即FFT,opr為-1即IFFT(快速傅立葉逆變換)
{
	for(int i=0;i<lim;i++)
        if(i<rev[i])
            swap(a[i],a[rev[i]]); //藉助蝴蝶迭代預處理得到的陣列以快速交換
	for(int md=1;md<lim;md<<=1)
	{
		cp rt=cp(cos(PI/md),opr*sin(PI/md)); //以尤拉公式快速獲得root
		for(int stp=md<<1,pos=0;pos<lim;pos+=stp)
		{
			cp w=cp(1,0);
			for(int i=0;i<md;i++,w=w*rt)
			{
				cp x=a[pos+i],y=w*a[pos+md+i];
				a[pos+i]=x+y;
				a[pos+md+i]=x-y;
			}
		}
	}
}

完整的板子

const int N=200050;

int lim=1,rev[N];

struct cp
{
	double x,y;
	cp(double u=0,double v=0){x=u,y=v;}
	friend cp operator +(const cp &u,const cp &v){return cp(u.x+v.x,u.y+v.y);}
	friend cp operator -(const cp &u,const cp &v){return cp(u.x-v.x,u.y-v.y);}
	friend cp operator *(const cp &u,const cp &v){return cp(u.x*v.x-u.y*v.y,u.x*v.y+u.y*v.x);}
}f[N],g[N];

void FFT(cp *a,int opr)
{
	for(int i=0;i<lim;i++)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
	for(int md=1;md<lim;md<<=1)
	{
		cp rt=cp(cos(PI/md),opr*sin(PI/md));
		for(int stp=md<<1,pos=0;pos<lim;pos+=stp)
		{
			cp w=cp(1,0);
			for(int i=0;i<md;i++,w=w*rt)
			{
				cp x=a[pos+i],y=w*a[pos+md+i];
				a[pos+i]=x+y;
				a[pos+md+i]=x-y;
			}
		}
	}
}

void initFFT(int n)
{
    int lg=0;
    while(lim<=n)
        lg++,lim<<=1;
    for(int i=0;i<lim;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
}


模板題以及板子的運用

Luogu P3803 - 多項式乘法

P3803 【模板】多項式乘法(FFT)

單純的模板題,問的即為兩多項式的卷積(即多項式相乘後得到的新多項式)

解法:

將兩個多項式進行FFT,變換成點值表示法後直接在對應位置相乘,最後進行IFFT即可

#include<bits/stdc++.h>
using namespace std;

const int N=5000050;
const double PI=acos(-1.0);

int lim=1,rev[N];
struct cp
{
	double x,y;
	cp(double u=0,double v=0){x=u,y=v;}
	friend cp operator +(const cp &u,const cp &v){return cp(u.x+v.x,u.y+v.y);}
	friend cp operator -(const cp &u,const cp &v){return cp(u.x-v.x,u.y-v.y);}
	friend cp operator *(const cp &u,const cp &v){return cp(u.x*v.x-u.y*v.y,u.x*v.y+u.y*v.x);}
}f[N],g[N];

void FFT(cp *a,int tp)
{
	for(int i=0;i<lim;i++)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
	for(int md=1;md<lim;md<<=1)
	{
		cp rt=cp(cos(PI/md),tp*sin(PI/md));
		for(int stp=md<<1,pos=0;pos<lim;pos+=stp)
		{
			cp w=cp(1,0);
			for(int i=0;i<md;i++,w=w*rt)
			{
				cp x=a[pos+i],y=w*a[pos+md+i];
				a[pos+i]=x+y;
				a[pos+md+i]=x-y;
			}
		}
	}
}

void initFFT(int n)
{
    int lg=0;
    while(lim<=n)
        lg++,lim<<=1;
    for(int i=0;i<lim;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
}

int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    initFFT(n+m);
    for(int i=0;i<=n;i++)
        cin>>f[i].x; //存在複數的實部中
    for(int i=0;i<=m;i++)
        cin>>g[i].x;
    FFT(f,1);
    FFT(g,1);
    for(int i=0;i<=lim;i++)
        f[i]=f[i]*g[i]; //直接計算卷積
    FFT(f,-1); //逆變換
    for(int i=0;i<=n+m;i++)
        printf("%d ",(int)round(f[i].x/lim)); //逆變換後需要除以lim,注意四捨五入(精度問題)
    return 0;
}


Luogu P1919 - A*B Problem升級版

P1919 【模板】A*B Problem升級版(FFT快速傅立葉)

解法:

FFT的處理部分與上題類似

注意到一個十進位制數字可以看作是一個\(x=10\)的多項式

例如\(123=1*10^2+2*10+3\)

所以兩個數字的乘法即為兩個\(x=10\)的多項式相乘

注意最後逆變換後可能會導致某個位置的係數\(≥10\),所以注意進位處理即可

#include<bits/stdc++.h>
using namespace std;

const int N=4000050;
const double PI=acos(-1.0);

int lim=1,rev[N];
struct cp
{
	double x,y;
	cp(double u=0,double v=0){x=u,y=v;}
	friend cp operator +(const cp &u,const cp &v){return cp(u.x+v.x,u.y+v.y);}
	friend cp operator -(const cp &u,const cp &v){return cp(u.x-v.x,u.y-v.y);}
	friend cp operator *(const cp &u,const cp &v){return cp(u.x*v.x-u.y*v.y,u.x*v.y+u.y*v.x);}
}f[N],g[N];

void FFT(cp *a,int tp)
{
	for(int i=0;i<lim;i++)
        if(i<rev[i])
            swap(a[i],a[rev[i]]);
	for(int md=1;md<lim;md<<=1)
	{
		cp rt=cp(cos(PI/md),tp*sin(PI/md));
		for(int stp=md<<1,pos=0;pos<lim;pos+=stp)
		{
			cp w=cp(1,0);
			for(int i=0;i<md;i++,w=w*rt)
			{
				cp x=a[pos+i],y=w*a[pos+md+i];
				a[pos+i]=x+y;
				a[pos+md+i]=x-y;
			}
		}
	}
}

void initFFT(int n)
{
    int lg=0;
    while(lim<=n)
        lg++,lim<<=1;
    for(int i=0;i<lim;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(lg-1));
}

char s1[N],s2[N];
int ans[N];

int main()
{
    scanf("%s%s",s1,s2);
    initFFT(strlen(s1)+strlen(s2));
    for(int i=strlen(s1)-1,j=0;i>=0;i--,j++) //倒序存入,十進位制數字次數最低的位在最後
        f[j].x=s1[i]-'0';
    for(int i=strlen(s2)-1,j=0;i>=0;i--,j++)
        g[j].x=s2[i]-'0';
    FFT(f,1);
    FFT(g,1);
    for(int i=0;i<=lim;i++)
        f[i]=f[i]*g[i];
    FFT(f,-1);
    int tmp=0;
    for(int i=0;i<=lim;i++)
    {
        ans[i]+=(int)(f[i].x/lim+0.5);
        if(ans[i]>=10)
        {
            ans[i+1]+=ans[i]/10;
            ans[i]%=10;
            if(i==lim)
                lim++;
        }
    }
    while(ans[lim]==0&&lim>=1)
        lim--;
    while(lim>=0)
    {
        printf("%d",ans[lim]);
        lim--;
    }
    putchar('\n');
    
    return 0;
}


這裡是借鑑的部落格