FFT/NTT複習筆記&多項式&生成函式學習筆記
眾所周知,tzc 在 2019 年(12 月 31 日)就第一次開始接觸多項式相關演算法,可到 2021 年(1 月 1 日)才開始寫這篇 blog。
感覺自己開了個大坑(
多項式乘法
好吧這個應該是多項式各種運算中的基礎了。
首先,在學習多項式乘法之前,你需要學會:
複數
我們定義虛數單位 \(i\) 為滿足 \(x^2=-1\) 的 \(x\)。
那麼所有的複數都可以表示為 \(z=a+bi\) 的形式,其中 \(a,b\) 均為實數。
複數的加減直接對實部虛部相加減就行了。
複數的乘法稍微用下乘法分配律就有 \((a+bi)(c+di)=(ac-bd)+(ad+bc)i\)
複平面
我們以實部為橫座標,虛部為縱座標建立平面直角座標系,我們稱得到的平面為“複平面
例如 \((1,0)\) 就表示 \(1\),\((-1,1)\) 就表示 \(-1+i\)
定義複數的模長為其所表示的點與原點間的距離,即 \(|z|=\sqrt{a^2+b^2}\)
我們知道,高中階段會學兩種座標系,一是平面直角座標系,二是極座標系。前面我們提到的 \((a,b)\) 就是 \(z\) 在平面直角座標系下的座標,據此也可類比出 \(z\) 在極座標系下的座標 \((r,\theta)\)。顯然 \(r\) 就是複數 \(z\) 的模長,\(\theta\) 為複數 \(z\) 所表示的向量 \((a,b)\)
根據三角函式顯然有 \(a=r\cos\theta,b=r\sin\theta\)
考慮將兩個用極座標表示的複數 \((r_1,\theta_1),(r_2,\theta_2)\) 相乘:
\((r_1,\theta_1)(r_2,\theta_2)\)
\(=(r_1\cos\theta_1+r_1\sin\theta_1i)(r_2\cos\theta_2+r_2\sin\theta_2i)\)
\(=(r_1r_2\cos\theta_1\cos\theta_2-r_1r_2\sin\theta_1\sin\theta_2)+(r_1r_2\sin\theta_1\cos\theta_2+r_1r_2\cos\theta_1\sin\theta_2)i\)
\(=r_1r_2(cos(\theta_1+\theta_2)+\sin(\theta_1+\theta_2))=(r_1r_2,\theta_1+\theta_2)\)
於是我們得到複數乘法的原則:模長相乘,幅角相加。
單位根
定義 \(n\) 次單位根為滿足 \(z^n=1\) 的複數 \(z\)。
考慮什麼樣的複數 \(z\) 滿足條件,假設 \(z\) 寫成極座標形式為 \((r,\theta)\),那麼 \(z^n=(r^n,n\theta)\)。
而 \(z^n=(1,0)\)。故這樣的複數的模長均為 \(1\),幅角乘 \(n\) 為 \(2\pi\) 的整數倍。
故 \(r=1,\theta=\dfrac{2k\pi}{n}\),其中 \(k\) 為整數。
定義 \(\omega_n^i\) 為 \((1,\dfrac{2i\pi}{n})=\cos(\dfrac{2i\pi}{n})+i\sin(\dfrac{2i\pi}{n})\)。
單位根有如下性質(由於都比較顯然就不一一證明了):
- \(\omega_n^i=\omega_n^{i+n}\)
- \(\omega_n^i=(\omega_n^1)^i\)
- 若 \(n\) 為偶數,則 \(\omega_n^i=-\omega_n^{i+n/2}\)
- \(\omega_i^n=\omega_{i/2}^{n/2}\)
單位根反演
qwq 這東西在三週前 XES 講母函式的時候曾經講過
單位根反演就是如下等式:
\([i\equiv 0\pmod{n}]=\dfrac{1}{n}\sum\limits_{t=0}^{n-1}(\omega_n^{i})^t\)
證明:
若 \(i\) 不是 \(n\) 的倍數,則 \(\omega_i\neq 1\),\(\dfrac{1}{n}\sum\limits_{t=0}^{n-1}(\omega_n^{i})^t=\dfrac{(\omega_n^i)^n-1}{\omega_n^i-1}\),由於 \((\omega_i)^n-1=0,\omega_i-1\neq 0\),故原式 \(=0\)。
若 \(i\) 是 \(n\) 的倍數,則 \(\omega_i=1\),式子中每項都是 \(1\),加起來除以 \(n\) 就是 \(1\)。
下面終於進入正題了:
快速傅立葉變換
對於多項式 \(A(x)=\sum\limits_{i=0}^na_ix^i,B(x)=\sum\limits_{i=0}^nb_ix^i\),設 \(C(x)=A(x)B(x)=\sum\limits_{i=0}^{n+m}c_ix^i\)
那麼顯然 \(c_i=\sum\limits_{x+y=i}a_xb_y\)
設 \(N\) 為大於 \(n+m\) 且最小的滿足 \(N=2^k\)(\(k\) 為整數)的 \(N\)
由於 \(N>n+m\),所有 \(x+y=i\) 等價於 \(x+y-i\equiv 0\pmod{N}\)
至於 \(N\) 為什麼要是 \(2\) 的整數次冪,後面再說。
繼續推式子:
\(c_i=\sum\limits_{x+y-i\equiv 0\pmod{N}}a_xb_y\)
\(=\sum\limits_{x,y}a_xb_y\times\dfrac{1}{N}\sum\limits_{t=0}^{N-1}(\omega_N^{x+y-i})^t\)
\(=\dfrac{1}{N}\sum\limits_{x,y}\sum\limits_{t=0}^{N-1}\omega_N^{(x+y-i)t} a_xb_y\)
\(=\dfrac{1}{N}\sum\limits_{x,y}\sum\limits_{t=0}^{N-1}\omega_N^{-it} a_x\omega_N^{xt}b_y\omega_N^{yt}\)
\(=\dfrac{1}{N}\sum\limits_{t=0}^{N-1}\omega_N^{-it} \times\sum\limits_{x}a_x\omega_N^{xt}\times\sum\limits_{y}b_y\omega_N^{yt}\)
記 \(\hat{a}_t=\sum\limits_{i}a_i\omega_{N}^{it},\hat{b}_t=\sum\limits_{i}b_i\omega_{N}^{it}\),i.e,\(a_t\) 是將 \(\omega_N^t\) 代入多項式 \(A\) 後計算得到的結果,\(b_t\) 是將 \(\omega_n^t\) 代入多項式 \(B\) 後計算得到的結果。
則 \(c_i=\dfrac{1}{N}\sum\limits_{t=0}^{N-1}\omega_N^{-it} \hat{a}_t\hat{b}_t\)
稍微觀察一下即可發現,若記 \(d_t=\hat{a}_t\hat{b}_t\),那麼 \(c_i\) 就是 \(\omega_N^{-i}\) 代入多項式 \(D\) 後計算得到的結果。
是不是感覺與前面計算 \(\hat{a}_t,\hat{b}_t\) 的過程如出一轍?只不過 \(\omega_{N}^t\) 變成了 \(\omega_{N}^{-t}\) ?並且最後係數除了個 \(N\)?
故我們只需考慮計算 \(\hat{a}_t,\hat{b}_t\) 的過程,根據 \(\hat{a}_t,\hat{b}_t\) 計算 \(c_i\) 的過程同理即可。
考慮分治地計算 \(\hat{a},\hat{b}\)。
將 \(a\) 按照下標分為奇數和偶數兩部分,每部分都是一個長為 \(\dfrac{N}{2}\) 的陣列,不妨設兩個陣列為 \(a_{0,i}\) 和 \(a_{1,i}\),其中 \(a_{0,i}=a_{2i},a_{1,i}=a_{2i+1}\)
那麼顯然有 \(\hat{a}_t=\sum\limits_{i=0}^{N/2-1}\omega_{N}^{2it}a_{0,i}+\sum\limits_{i=0}^{N/2-1}\omega_{N}^{(2i+1)t}a_{1,i}\)。
把右邊的 \(\omega_N^{(2i+1)t}\) 拆成 \(\omega_{N}^{2it}\times \omega_{N}^t\) 可得:
\(\hat{a}_t=\sum\limits_{i=0}^{N/2-1}\omega_{N}^{2it}a_{0,i}+\omega_{N}^t\sum\limits_{i=0}^{N/2-1}\omega_{N}^{2it}a_{1,i}\)
而由 \(\omega_i^n=\omega_{i/2}^{n/2}\) 得 \(\hat{a}_t=\sum\limits_{i=0}^{N/2-1}\omega_{N/2}^{it}a_{0,i}+\omega_{N}^t\sum\limits_{i=0}^{N/2-1}\omega_{N/2}^{it}a_{1,i}\)
設 \(\hat{a}_{0,t}=\sum\limits_{i=0}^{N/2-1}\omega_{N/2}^{it}a_{0,i},\hat{a}_{1,t}=\sum\limits_{i=0}^{N/2-1}\omega_{N/2}^{it}a_{1,i}\)
分治地計算 \(\hat{a}_{0,i},\hat{a}_{0,1}\)
最後 \(\hat{a}_t=\hat{a}_{0,t}+\omega_{N}^t\hat{a}_{1,t}\)
不過注意,按照定義域這裡的 \(t\) 是在 \([0,N/2-1]\) 範圍內的,因此對於 \(t\in[0,N/2-1]\) 才能用這個式子計算。不過注意到 \(\hat{a}_{t+N/2}=\hat{a}_{0,t}+\omega_N^{t+N/2,N}\hat{a}_{1,t}=\hat{a}_{0,t}-\omega_{N}^t\hat{a}_{1,t}\),用這個公式可以計算出 \(\hat{a}_{t+N/2}\)
上述推理過程寫成程式碼的形式如下:
const int MAXN=1<<21;//pay sepcial attention to array size
struct comp{
double x,y;//(real,imag)
comp(){x=y=0;}
comp(double _x,double _y){x=_x;y=_y;}
friend comp operator +(comp lhs,comp rhs){return comp(lhs.x+rhs.x,lhs.y+rhs.y);}
friend comp operator -(comp lhs,comp rhs){return comp(lhs.x-rhs.x,lhs.y-rhs.y);}
friend comp operator *(comp lhs,comp rhs){return comp(lhs.x*rhs.x-lhs.y*rhs.y,lhs.x*rhs.y+lhs.y*rhs.x);}
} f[MAXN+5],g[MAXN+5],h[MAXN+5];
const double Pi=acos(-1);//the value of Pi
int n,m,LEN=1;//LEN is the smallest N such that N>n+m and N=2^k (k is integer)
void FFT(comp *a,int len,int type){
if(len==1) return;//the point value of a constant is the constant itself
comp a0[len>>1],a1[len>>1];//a0[i] and a1[i]
for(int i=0;i<len;i+=2) a0[i>>1]=a[i],a1[i>>1]=a[i+1];
FFT(a0,len>>1,type);FFT(a1,len>>1,type);//find the value of \hat{a}_0 and \hat{a}_1
comp W=comp(cos(2*Pi/len),type*sin(2*Pi/len)),w=comp(1,0);//omega_{len}^1
for(int i=0;i<len;i++,w=w*W){
if(i<(len>>1)) a[i]=a0[i]+w*a1[i];
else a[i]=a0[i-(len>>1)]+w*a1[i-(len>>1)];
}
}
int main(){
scanf("%d%d",&n,&m);
for(int i=0;i<=n;i++) scanf("%lf",&f[i].x);
for(int i=0;i<=m;i++) scanf("%lf",&g[i].x);
while(LEN<=(n+m)) LEN<<=1;
FFT(f,LEN,1);FFT(g,LEN,1);
for(int i=0;i<LEN;i++) h[i]=f[i]*g[i];
FFT(h,LEN,-1);for(int i=0;i<=n+m;i++) printf("%d ",(int)(h[i].x/LEN+0.5));//remember to divide c[i] by N
return 0;
}
在上面的程式碼中,有一些點需要注意:
- FFT 陣列的大小需要注意,假設兩個多項式長度最大值分別為 \(n,m\),那麼你所開陣列的大小應至少為 \(2^{\lceil\log_2(n+m)\rceil}\)
- 最後別忘了將陣列中所有值除以 \(N\)。
迭代 FFT
由於 FFT 常數很大,需要進行優化。
如圖所示,我們手推一下遞迴的過程,先將其分成了 \(0,2,4,6\) 和 \(1,3,5,7\) 兩組,又將 \(0,2,4,6\) 分成了 \(0,4\) 和 \(2,6\) 兩組,將 \(1,3,5,7\) 分成了 \(1,5\) 和 \(3,7\) 兩組。
然後對 \(0,4\) 進行合併,\(2,6\) 進行合併,\(1,5\) 進行合併,\(3,7\) 進行合併;然後合併長度為 \(4\) 的區間 \(0,4,2,6\) 和 \(1,5,3,7\),最後合併長度為 \(8\) 的區間 \(0,4,2,6,1,5,3,7\)。
於是我們考慮把 \(a_0,a_1,a_2,\dots,a_8\) 交換位置交換到 \(a_0,a_4,a_2,a_6,a_1,a_5,a_3,a_7\) 的位置。
然後合併長度為 \(2\) 的區間 \([2i,2i+1]\),再合併 \(4\) 的區間 \([4i,4i+3]\),然後是 \(8\) 的區間,以此類推。
那麼變換前的下標和變換後的下標有什麼規律呢?
不難發現 \(1\) 的二進位制是 \(001\),\(4\) 的二進位制是 \(100\),\(1\) 與 \(4\) 二進位制下剛好互為翻轉串。然後你發現 \(2\) 與 \(2\),\(3\) 與 \(6\),\(5\) 與 \(5\) 也有這樣的規律。
因此我們就可以得出這樣的規律:變換後下標為 \(i\) 的數在原來的 \(a_i\) 中的下標在二進位制下互為翻轉串。
至於怎麼求 \(i\) 變換後的下標 \(rev_i\),暴力地 \(n\log n\) 求肯定是 ok 的。不過考慮到常數的原因,最好用類似數位 \(dp\) 的方法達到 \(O(n)\) 的複雜度。
for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(LOG-1));
這樣我們就有了不用遞迴的實現方式:
void FFT(comp *a,int len,int type){
for(int i=0;i<len;i++) if(i<rev[i]) swap(a[i],a[rev[i]]);
for(int i=1;i<len;i<<=1){
comp W=comp(cos(Pi/i),type*sin(Pi/i));
for(int j=0;j<len;j+=i*2){
comp w(1,0);
for(int k=0;k<i;k++,w=w*W){
comp x=a[j+k],y=w*a[i+j+k];
a[j+k]=x+y;a[i+j+k]=x-y;
}
}
}
}
通過以上探究,我們學會了 FFT,它有不少優點,譬如它能在高效地求出兩個多項式的卷積。但同時它也有許多缺點,例如使用 FFT 可能會出現精度問題(因為使用了 double),並且它也不支援取模等等。那假如我們真的需要取模怎麼辦呢?
這就需要 NTT。