FFT & NTT 學習筆記
機房人均多項式帶師,就我啥都不會!
所以來填坑了qwq
FFT
前置知識
複數:\(z=a+bi\) ,其中 \(i=\sqrt{-1}\) . 複數可以表示在 複平面 上,\(z\) 橫座標為 \(a\) ,縱座標為 \(b\) . 簡單瞭解複數 “幅角”、“模長”的概念。
基礎三角比。知道 \(\sin,\cos,\tan\) .
卷積概念:形如 \(C[k]=\sum_{i\oplus j=k}A[i]B[j]\) 的式子成為卷積,其中 \(\oplus\) 為運算子。多項式乘法就是加法卷積。
DFT & IDFT 思想
\(F(x)=a_nx^n+a_{n-1}x^{n-1}+\dots +a_0\)
注意到 \(n+1\) 個點及其對應的 \(F(x)\) 可以唯一確定一個 \(n\) 次多項式,所以又有了 點值表示法 ,即用 \(n+1\) 個有序數對來表示一個多項式。
設 \(F(x)\) 可以表示為數列 \(X=\{x_0,x_1,\dots,x_n\}\) ,\(G(x)\) 表示為 \(Y=\{y_0,y_1,\dots y_n\}\) (這裡省略了點的橫座標,預設兩個數列對應同一組橫座標),那麼不難想到, \(F(x)*G(x)\) 可以表示為 \(\{x_0y_0,x_1y_1,\dots ,x_ny_n\}\)
然而我們需要的是係數係數表示法。所以不難想到實現卷積的流程:係數轉點值(求值) $\to $ 點值相乘 \(\to\) 點值轉系數(插值)。
在 FFT 中,第一步叫做 DFT ,最後一步叫做 IDFT ( DFT 逆運算 )。
單位根
前面已經提到了有複平面這個東西。現在我們在上面以原點為圓心,畫一個半徑為 \(1\)
如圖,這是 \(8\) 等分的結果。
現在 單位根 出現了:以原點為起點,上面得到的 \(n\) 等分點為終點,作 \(n\) 個向量,設幅角為正且最小的向量對應的複數為 \(\omega_n\) ,稱為 \(n\) 次單位根。根據複數乘法不難得到,其餘 \(n-1\) 個向量對應複數依次為 \(\omega_n^2,\omega_n^3,\dots ,\omega_n^n\) . 特別地,\(\omega_n^0=\omega_n^n=1\) .(即 \(x\) 軸正方向的那個向量) 比如上圖中,\(\vec{AB}\) 就代表了 \(\omega_8\) .
這裡有一些相關性質:
- \(\omega_n^k=(\omega_n^1)^k\) ,乘一個 \(\omega_n^1\) 的幾何意義就是把幅角逆時針轉動 \(\dfrac{1}{n}\) 個周角。
- \(\omega_n^j\times \omega_n^k=\omega^{j+k},\omega_{2n}^{2k}=\omega_n^k\)
- 如果 \(n\) 為偶數,那麼有 \(\omega_n^{k+n/2}=-\omega_n^k\) ,相當於轉了半個周角,自然是取反。
FFT
現在來考慮一個 \(n-1\) 次多項式 \(F(x)\) ,每一項按照下標奇偶分開:(這裡設 \(n\) 是 \(2\) 的冪次,不是可以在高次補一些係數為 \(0\) 的項)
\[F(x)=(a_0+a_2x^2+\dots +a_{n-2}x^{n-2})+(a_1x+\dots +a_{n-1}x^{n-1}) \]為了方便,記
\[FL(x)=a_0+a_2x^2+\dots+a_{n-2}x^{n/2-1},FR(x)=a_1+a_3x+\dots+a_{n-1}x^{n/2-1} \]那麼有
\[F(x)=FL(x^2)+xFR(x^2) \]現在把 \(\omega_n^k\) 代入:
- \(k<n/2\) ,代入 \(\omega_n^k\)
- \(k<n/2\) ,代入 \(\omega_n^{k+n/2}\)
於是這兩個式子只有 \(FR\) 前面的符號不同。所以如果 \(FL(x),FR(x)\) 在 \(\omega_{n/2}^0,\cdots,\omega_{n/2}^{n/2-1}\) 的點值表示,就能 \(\mathcal{O}(n)\) 求出 \(F(x)\) 在 \(\omega_n^0,\cdots ,\omega_n^{n-1}\) 的點值表示。顯然,這樣的過程可以直接分治實現。
上面已經實現了 DFT ,現在來看 IDFT ,即 DFT 的逆運算。
有結論:把 DFT 中的 \(\omega_n^1\) 換成 \(\omega_n{-1}\) ,用卷積之後得到的多項式放進去做一遍,然後除以 \(n\) 即可。具體證明參見文末參考文獻。
於是這樣 DFT 和 IDFT 就能一個函式實現了。
具體實現
預處理單位根 & 合併
如果正常每次算一遍單位根,複雜度是 \(\mathcal{O}(n\log n)\) 的,預處理出單位根就是 \(\mathcal{O}(n)\) ,能減小很多常數。
首先用 \(\left(\cos\left(\dfrac{2\pi}{n}\right),\sin\left(\dfrac{2\pi}{n}\right)\right)\) 求出 \(\omega_n^1\) ,其餘直接複數乘上去即可。複數手寫結構體就好,當然不怕常數也能用 STL 。
一種更優的寫法看程式碼實現。
合併陣列時,最簡單的方法就是開一個臨時陣列,用當前 \(f\) 往那裡貢獻,最後再 copy 一遍。這樣顯然不優良。
嘗試改變賦值順序,類似 DP 一樣分析 \(f\) 貢獻時哪些會改變,只要保證當前往結果貢獻的這部分還是這一輪之前(也就是沒被這一輪操作覆蓋)的結果就可以了。
蝴蝶變換
觀察我們奇偶分治的過程,發現最後的序列下標對應原序列下標二進位制翻轉之後的結果。那麼我們並不需要每次都把一個數組拷來拷去,還按照奇偶分成兩個,可以預處理出二進位制翻轉的結果,直接賦值。這樣通過遞推實現,複雜度是 \(\mathcal{O}(n)\) 的。
for ( int i=0; i<lim; i++ )
rev[i]=(rev[i>>1]>>1) | ((i&1) ? lim>>1 : 0);
注意後面的 rev[i>>1]>>1
,記住當前是以 \(lim-1\) (某個二進位制下全 \(1\) 的東西)為最高位的, rev[i>>1]
的末尾會多出來一位,要把這一位去掉。比如 rev[1]=100
,而 rev[2]=(100>>1)|0
,010
先右移取了前兩位,但是這是 實際的值 ,正常的沒有翻轉的 \(1\) 是會在最開頭再多一個 \(0\) 的,所以翻轉之後要把這個 \(0\) 扔掉。不懂就手模
三次變兩次
令 \(P(x)=F(x)+G(x)i\) ,那麼 \(P(x)^2=F(x)^2-G(x)^2+2F(x)G(x)i\) ,所求就是虛部的一半。所以只需要兩次 FFT 即可。
Warning:值域相差過大會卡精度,可以數乘到相同值域再做。
模板
題目連結 這題輸入輸出量比較大,所以快讀快寫能起到很大的優化。這個板子大概是平均 930ms
左右。
//Author:RingweEH
//P3803 【模板】多項式乘法(FFT)
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
char buf[1<<20],*p1=buf,*p2=buf;
int read()
{
int x=0; char ch=getchar();
while ( ch>'9' || ch<'0' ) ch=getchar();
while ( ch<='9' && ch>='0' ) x=x*10+ch-48,ch=getchar();
return x;
}
void write( int x ) { if(x<0) putchar(45); if ( x>9 ) write(x/10); putchar(x%10+48); }
const int N=2100000;
const db PI=acos(-1.0)*2;
int n,m,rev[N];
struct Complex
{
db x,y;
Complex( db _x=0,db _y=0 ) : x(_x),y(_y) {}
}F[N];//,G[N];
Complex operator + ( Complex t1,Complex t2 ) { return Complex(t1.x+t2.x,t1.y+t2.y); }
Complex operator - ( Complex t1,Complex t2 ) { return Complex(t1.x-t2.x,t1.y-t2.y); }
Complex operator * ( Complex t1,Complex t2 ) { return Complex(t1.x*t2.x-t1.y*t2.y,t1.x*t2.y+t1.y*t2.x); }
void FFT( Complex *f,bool fl )
{
int i,k,len,l;
for ( i=0; i<n; i++ )
if ( i<rev[i] ) swap( f[i],f[rev[i]] );
for ( k=2,len; k<=n; k<<=1 )
{
len=k>>1; Complex w(cos(PI/k),sin(PI/k));
if ( !fl ) w.y*=-1;
for ( i=0; i<n; i+=k )
{
Complex buf(1,0);
for ( l=i; l<i+len; l++ )
{
Complex tmp=buf*f[len+l];
f[len+l]=f[l]-tmp; f[l]=f[l]+tmp;
buf=buf*w;
}
}
}
}
int main()
{
n=read(); m=read(); int i;
for ( i=0; i<=n; i++ ) F[i].x=read();
for ( i=0; i<=m; i++ ) F[i].y=read();
for ( m+=n,n=1; n<=m; n<<=1 );
for ( i=0; i<n; i++ ) rev[i]=(rev[i>>1]>>1)|((i&1) ? n>>1 : 0);
FFT( F,1 ); //FFT( G,1 ); //DFT
for ( i=0; i<n; i++ ) F[i]=F[i]*F[i];
FFT( F,0 ); //IDFT
for ( i=0; i<=m; i++ )
write((int)(F[i].y/n/2+0.49)),putchar(32);
//printf( "%d ",(int)(F[i].y/n/2+0.49) ); //printf( "%d ",(int)(F[i].x/n+0.49) );
return 0;
}
NTT
前置知識
如果 \(a,p\) 互質且 \(p>1\) ,對於 \(a^n\equiv 1(\bmod p)\) 最小的 \(n\) ,稱為 \(a\) 模 \(p\) 的階,記為 \(\delta_p(a)\) .
原根
FFT 依賴單位根,所以必須採用浮點數,引發精度問題。NTT 就是 FFT 在模意義下的替代品。這裡,原根代替了單位根。
先考慮單位根有哪些性質:
- \(\omega_n^k=(\omega_n^1)^k\)
- \(\omega_n^{0\sim n-1}\) 互不相同
- \(\omega_n^k=\omega_n^{k\bmod n}\) 。前三條等價於 \(\omega_n^1\) 在模意義下階恰為 \(n\) .
- \(\omega_{2n}^{2k}=\omega_n^k\)
原根的定義:對於一個素數 \(p\) ,如果 \(g\) 的階達到 \(p-1\) 的上界,稱 \(g\) 為原根。
注意到 \(g^k\) 的階恰為 \((p-1)/\gcd(k,p-1)\) ,這個數仍然是 \(p-1\) 的約數。所以,當 \(n\nmid (p-1)\) 時,找不到階恰為 \(n\) 的數。
當 \(n\mid (p-1)\) 時,\(g^{(p-1)/n}\) 的階為 \(n\) ,且不難發現也滿足最後一個條件。
由於演算法中 \(n\) 往往是 \(2\) 的冪次,我們只需要構造一個質數 \(p\) 使得 \(p-1\) 包含大量因子 \(2\) 即可。
常用原根:詳細版本 不用原根的 trick
\(p\) | \(g\) |
---|---|
\(998244353=119\cdot 2^{23}+1\) | \(3\) |
\(2281701377=17\cdot 2^{27}+1\) | \(3\) |
\(1004535809=479\cdot 2^{21}+1\) | \(3\) |
具體實現
為什麼沒有演算法講解?因為沒有本質區別QWQ
額外技巧:
- 預處理原根
- 由於只有加減法操作,可以用
unsigned long long
儲存,能承受大概18*Mod*Mod
的範圍,所以常規範圍下可以不取模,範圍較大就中間取模,儘量減少次數。
附送正常版的 NTT:
//Author:RingweEH
//P3803 【模板】多項式乘法(FFT)
#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
char buf[1<<20],*p1=buf,*p2=buf;
int read()
{
int x=0; char ch=getchar();
while ( ch>'9' || ch<'0' ) ch=getchar();
while ( ch<='9' && ch>='0' ) x=x*10+ch-48,ch=getchar();
return x;
}
void write( int x ) { if(x<0) putchar(45); if ( x>9 ) write(x/10); putchar(x%10+48); }
void swap( ll &a,ll &b ) { a^=b; b^=a; a^=b; }
const int N=2100000;
const ll Mod=998244353,G=3,InvG=332748118;
int n,m,rev[N];
ll f[N],g[N],Invn;
ll power( ll a,ll b=Mod-2 )
{
ll res=1;
for ( ; b; b>>=1,a=a*a%Mod )
if ( b&1 ) res=res*a%Mod;
return res;
}
void NTT( ll *f,bool fl )
{
int i,k,len,l;
for ( i=0; i<n; i++ )
if ( i<rev[i] ) swap( f[i],f[rev[i]] );
for ( k=2; k<=n; k<<=1 )
{
len=k>>1; ll nwG=power( fl ? G : InvG,(Mod-1)/k );
for ( i=0; i<n; i+=k )
{
ll buf=1;
for ( l=i; l<i+len; l++ )
{
ll tmp=buf*f[len+l]%Mod;
f[len+l]=f[l]-tmp;
if ( f[len+l]<0 ) f[len+l]+=Mod;
f[l]=f[l]+tmp;
if ( f[l]>=Mod ) f[l]-=Mod;
buf=buf*nwG%Mod;
}
}
}
}
int main()
{
n=read(); m=read(); int i;
for ( i=0; i<=n; i++ ) f[i]=read();
for ( i=0; i<=m; i++ ) g[i]=read();
for ( m+=n,n=1; n<=m; n<<=1 );
for ( i=0; i<n; i++ ) rev[i]=(rev[i>>1]>>1)|((i&1) ? n>>1 : 0);
NTT( f,1 ); NTT( g,1 );
for ( i=0; i<n; i++ ) f[i]=f[i]*g[i]%Mod;
NTT( f,0 ); Invn=power(n);
for ( i=0; i<=m; i++ )
write((int)(f[i]*Invn%Mod)),putchar(32);
return 0;
}