1. 程式人生 > 其它 >淺析快速傅立葉變換

淺析快速傅立葉變換

簡介

快速傅立葉變換(Fast Fourier Transform)是一種可以在 $O(n\log n) $ 複雜度下完成離散傅立葉變換(Discrete Fourier Transfrom)的演算法,常應用於加速多項式乘法。

多項式表示法

係數表示法

係數表示法就是用多項式各項係數來表達這個多項式:

\[f(x)=a_0+a_1x+\cdots +a_nx^n\Leftrightarrow f(x)=\{a_0,a_1,\cdots,a_n\} \]

點值表示法

點值表示法就是把多項式看作一個函式,對於一個 \(n\) 次多項式,任取 \(n+1\) 個在函式上的點,這樣可以唯一確定這個多項式:

\[f(x)=a_0+a_1x+\cdots a_nx^n\Leftrightarrow f(x)=\{y_0,y_1,\cdots,y_n\}(\forall i,\exist x\text{ s.t.} f(x)=y_i) \]

複數

在推導傅立葉變換前,我們需要掌握一些複數的基本性質:

  • 複數運算滿足結合律/交換律/分配律
  • 複數 \(z=a+bi\) 的模長 \(|z|=\sqrt{a^2+b^2}\) ,幅角 \(\theta\) 為實軸的正半軸逆時針旋轉到 \(z\) 的有向角度
  • 兩個複數的乘法滿足模長相乘,幅角相加

單位根

定義

將複平面上的單位圓等分成 \(n\) 個部分,定義其中幅角為正且最小的等分點對應的複數為 \(n\)

次單位根,記作 \(\omega_n\) ,那麼其餘的 \(n-1\) 個等分點對應的複數分別為 \(\omega_n^2,\omega_n^3,\cdots,\omega_n^n\) ,其中 \(\omega_n^n=\omega_n^0=1\) ,一般地,有:

\[\omega_n^k=\cos(2\pi\cdot \frac{k}{n})+i\sin(2\pi\cdot\frac{k}{n}) \]

\(n=4\) 時影象如下:

折半定理

\[\omega_{2n}^{2k}=\omega_n^k \]

由幾何意義/代入公式即可證明

消去定理

\[\omega_n^{k+\frac2n}=-\omega_n^k \]

由幾何意義/代入公式即可證明

離散傅立葉變換

考慮一個含 \(n\) 項( \(n=2^t,t\in\mathbb{N}\) )的多項式 \(A(x)\) ,已知它的係數表示,將 \(n\) 次單位根的 \(0\sim n-1\) 次冪分別代入 \(A(x)\) 得到它的點值表示,這一過程稱為離散傅立葉變換(Discrete Fourier Transform)

如果樸素地代入求值,複雜度顯然為 \(O(n^2)\)FFT利用了單位根的一些性質來降低複雜度,對於 \(A(x)=a_0+a_1x+\cdots+a_{n-1}x^{n-1}\) ,我們按照奇偶進行分組:

\[\begin{aligned} A(x)&=(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})+(a_1x+a_3x^3+\cdots+a_{n-1}x^{n-1})\\ &=(a_0+a_2x^2+\cdots+a_{n-2}x^{n-2})+x(a_1+a_3x^2+\cdots+a_{n-2}x^{n-1}) \end{aligned} \]

\[A_1(x)=a_0+a_2x+\cdots+a_{n-2}x^{\frac{n-2}{2}}\\A_2(x)=a_1+a_3x+\cdots+a_{n-1}x^{\frac{n-2}{2}} \]

可以得到:

\[A(x)=A_1(x^2)+xA_2(x^2) \]

分類討論,當 \(0\leq k\leq \frac n 2-1\)

\[\begin{aligned} A(\omega_n^k)&=A_1(\omega_n^{2k})+\omega_n^kA_2(\omega_n^{2k})\\ &=A_1(\omega_{\frac n 2}^k)+\omega_n^kA_2(\omega_{\frac n 2}^k) \end{aligned} \]

\(\frac n 2\leq k+\frac n 2\leq n-1\)

\[\begin{aligned} A(\omega_n^{k+\frac n 2})&=A_1(\omega_n^{2k+n})+\omega_n^{k+\frac n 2}A_2(\omega_n^{2k+n})\\ &=A_1(\omega_n^n\cdot\omega_n^{2k})-\omega_n^kA_2(\omega_n^n\cdot\omega_n^{2k})\\ &=A_1(\omega_{\frac n 2}^k)-\omega_n^kA_2(\omega_{\frac n 2}^k) \end{aligned} \]

所以,如果求出了 \(A_1(x),A_2(x)\) 分別在 \(\omega_{\frac n 2}^0,\omega_{\frac n 2}^1,\cdots,\omega_{\frac n 2}^{\frac n 2-1}\) 的值,就可以用 \(O(n)\) 求出 \(A(\omega_n^0),A(\omega_n^1),\cdots,A(\omega_n^{n-1})\) ,那麼就得到了 \(A(x)\) 的點值表示

FFT的時間複雜度 \(T(n)\) 滿足:

\[T(n)=2T(\frac n 2)+n\Rightarrow T(n)=O(n\log n) \]

逆離散傅立葉變換

已知一個項數為 \(2\) 的次冪的多項式的點值表示,求它的係數表示,這一過程叫做逆離散傅立葉變換(Inverse Discrete Fourier Transform) ,我們仍可以在稍加變形後用FFT解決這一問題

\(\{d_0,d_1,\cdots,d_{n-1}\}\) 為多項式 \(\{a_0,a_1,\cdots,a_{n-1}\}\) 經過FFT得到的結果,即 \(d_i=A(\omega_n^i)\) ,構造一個多項式:

\[F(x)=d_0+d_1x+\cdots+d_{n-1}x^{n-1} \]

\(c_k=F(\omega_n^{-k})=\sum_{i=0}^{n-1}d_i\cdot(\omega_n^{-k})^i\)

那麼有:

\[\begin{aligned} c_k&=\sum_{i=0}^{n-1}d_i\cdot(\omega_n^{-k})^i\\ &=\sum_{i=0}^{n-1}[\sum_{j=0}^{n-1}a_j\cdot(\omega_n^i)^j]\cdot(\omega_n^{-k})^i\\ &=\sum_{j=0}^{n-1}a_j\sum_{i=0}^{n-1}(\omega_n^i)^{j-k} \end{aligned} \]

\(S(j,k)=\sum_{i=0}^{n-1}(\omega_n^i)^{j-k}\)

\(j=k\)\(S(j,k)=n\)

\(j\not= k\) ,根據等比求和公式有:

\[S(j,k)=\frac{\omega_n^0[(\omega_n^{j-k})^n-1]}{\omega_n^{j-k}-1}=\frac{(\omega_n^n)^{j-k}-1}{\omega_n^{j-k}-1}=\frac{1-1}{\omega_n^{j-k}-1}=0 \]

所以 \(\forall j,k,S(j,k)=[j=k]\cdot n\)

代入原式得:

\[c_k=\sum_{j=0}^{n-1}a_jS(j,k)=a_k\cdot n\Rightarrow a_k=\frac{c_k}{n} \]

FFT模板

遞迴

typedef complex<double> cp;
const int MAX_N = 1 << 20;

//if FFT, inv = false
bool inv = false;

//return w(n, k) or w(n, -k)
cp omega(int n, int k)
{
    if(inv)
        return cp(cos(2 * M_PI / n * k), sin(2 * M_PI / n * k));
    return cp(cos(2 * M_PI / n * k), -sin(2 * M_PI / n * k));
}

void fft(cp *a, int n)
{
    if(n == 1)
        return;
    static cp buf[MAX_N];
    int mid = n >> 1;
    for(int i = 0; i < mid; i++) {
        buf[i] = a[i << 1];
        buf[i + mid] = a[i << 1 | 1];
    }
    memcpy(a, buf, sizeof(cp) * (n + 1));

    cp *a1 = a, *a2 = a + mid;
    fft(a1, mid);
    fft(a2, mid);

    for(int i = 0; i < mid; i++) {
        cp t = omega(n, i);
        buf[i] = a1[i] + t * a2[i];
        buf[i + mid] = a1[i] - t * a2[i];
    }
    memcpy(a, buf, sizeof(cp) * (n + 1));
}

優化

遞迴版本的FFT需要輔助陣列,並且遞迴產生了較大的常數,所以我們把每次分組的情況列舉出來嘗試優化

觀察到每一個位置的數其實都是原來位置上的數的二進位制位翻轉了一下

於是我們可以先把原陣列調整成最底層的位置,然後從倒數第二層逐層向上計算,這就是FFT的 Cooley-Tukey 演算法,在這一演算法中,合併操作被稱為蝴蝶操作

\(1\) 開始由上到下對每一層編號,則從第 \(i\) 層到第 \(i-1\) 層需要 \(2^{i-1}\) 次合併。假設 \(A_1(\omega_{\frac n 2}^k)\)\(A_2(\omega_{\frac n 2}^k)\) 分別存在 \(a[k]\)\(a[k+\frac n 2]\) 中, \(A(\omega_n^k)\)\(A(\omega_n^{k+\frac n 2})\) 將要被存放在 \(buf[k]\)\(buf[k+\frac n 2]\) 中,合併的單位操作可表示為:

\[buf[k]:=a[k]+\omega_n^ka[k+\frac n 2]\\ buf[k+\frac n 2]=a[k]-\omega_n^ka[k+\frac n 2] \]

加入一個臨時變數並改變合併順序,我們就可以在原陣列內合併

\[t:=\omega_n^k\cdot a[k+\frac n 2]\\ a[k + \frac n 2]:=a[k]-t\\ a[k]:=a[k]+t \]
typedef complex<double> cp;

const int MAX_N = 1 << 22;
const double PI = acos(-1.0);

cp omega[MAX_N], inv[MAX_N];
cp x1[MAX_N], x2[MAX_N];
int sum[MAX_N << 1];

void init(int n)
{
    for(int i = 0; i < n; i++) {
        double a = cos(2 * PI / n * i), b = sin(2 * PI / n * i);
        omega[i] = cp(a, b);
        inv[i] = cp(a, -b);
    }
}

void transform(cp *a, int n, const cp *omega)
{
    for(int i = 0, j = 0; i < n; i++) {
        if(i > j)
            swap(a[i], a[j]);
        for(int l = n >> 1; (j ^= l) < l; l >>= 1)
            continue;
    }
    for(int i = 2; i <= n; i <<= 1) {
        int mid = i >> 1;
        for(cp *p = a; p != a + n; p += i) {
            for(int j = 0; j < mid; j++) {
                cp t = omega[n / i * j] * p[mid + j];
                p[mid + j] = p[j] - t;
                p[j] = p[j] + t;
            }
        }
    }
}

void dft(cp *a, int n)
{
    transform(a, n, omega);
}

void idft(cp *a, int n)
{
    transform(a, n, inv);
    for(int i = 0; i < n; i++)
        a[i] /= n;
}

多項式乘法

原理

考慮已知兩個多項式的係數表示 \(A(x)=\{a_0,a_1,\cdots,a_n\},B(x)=\{b_0,b_1,\cdots b_m\}\) ,要求它們的乘積的係數表示 \(C(x)=\{c_0,c_1,\cdots,c_{m+n}\}\) ,可以得到:

\[c_i=\sum_{j+k=i}a_jb_k \]

這樣做的複雜度為 \(O(n\times m)\) ,可以用這段程式碼表示:

for(int i = 0; i < n; i++)
	for(int j = 0; j < m; j++)
		c[i + j] += a[i] * b[j];

考慮如何用點值表示簡化計算,對於任意 \(n,m\) ,可以找到一個 \(t\) 滿足 \(2^t\geq2\max(n,m)\)\(2^{t-1}<2\max(n,m)\) ,我們把 \(A(x),B(x)\) 寫成 \(t\) 次多項式的形式,即:

\[A(x)=\{a_0,a_1,\cdots,a_n,0,0,\cdots\}\\ B(x)=\{b_0,b_1,\cdots,b_n,0,0,\cdots\} \]

再用DFT得到 \(A(x),B(x)\) 的點值表示,可以用 \(O(t)\) 推出 \(C(x)\) 的點值表示:

\[A(x)=\{x_0,x_1,\cdots,x_t\}\\ B(x)=\{y_0,y_1,\cdots,y_t\}\\ \Rightarrow C(x)=\{x_0y_0,x_1y_1,\cdots,x_ty_t\} \]

再用IDFT\(C(x)\) 的點值表示轉化為係數表示即可

例題

P1919 A*B Problem

高精度乘法運算可以看作多項式的乘法運算,求出多項式乘法結果後代入 \(x=10\) 即可

#include<bits/stdc++.h>
using namespace std;
typedef complex<double> cp;

const int MAX_N = 1 << 22;
const double PI = acos(-1.0);

cp omega[MAX_N], inv[MAX_N];
cp x1[MAX_N], x2[MAX_N];
int sum[MAX_N << 1];

void init(int n)
{
    for(int i = 0; i < n; i++) {
        double a = cos(2 * PI / n * i), b = sin(2 * PI / n * i);
        omega[i] = cp(a, b);
        inv[i] = cp(a, -b);
    }
}

void transform(cp *a, int n, const cp *omega)
{
    for(int i = 0, j = 0; i < n; i++) {
        if(i > j)
            swap(a[i], a[j]);
        for(int l = n >> 1; (j ^= l) < l; l >>= 1)
            continue;
    }
    for(int i = 2; i <= n; i <<= 1) {
        int mid = i >> 1;
        for(cp *p = a; p != a + n; p += i) {
            for(int j = 0; j < mid; j++) {
                cp t = omega[n / i * j] * p[mid + j];
                p[mid + j] = p[j] - t;
                p[j] = p[j] + t;
            }
        }
    }
}

void dft(cp *a, int n)
{
    transform(a, n, omega);
}

void idft(cp *a, int n)
{
    transform(a, n, inv);
    for(int i = 0; i < n; i++)
        a[i] /= n;
}

int main()
{
    string s1, s2;
    cin >> s1 >> s2;
    int len = 1, len1 = s1.size(), len2 = s2.size();
    while(len < len1 * 2 || len < len2 * 2)
        len <<= 1;
    for(int i = 0; i < len1; i++)
        x1[i] = cp(s1[len1 - i - 1] - '0');
    for(int i = len1; i < len; i++)
        x1[i] = cp(0);
    for(int i = 0; i < len2; i++)
        x2[i] = cp(s2[len2 - i - 1] - '0');
    for(int i = len2; i < len; i++)
        x2[i] = cp(0);
    init(len);
    dft(x1, len);
    dft(x2, len);
    for(int i = 0; i < len; i++)
        x1[i] = x1[i] * x2[i];
    idft(x1, len);
    for(int i = 0; i < len; i++)
        sum[i] = int(x1[i].real() + 0.5);
    for(int i = 0; i < len; i++) {
        sum[i + 1] += sum[i] / 10;
        sum[i] %= 10;
    }
    len = len1 + len2 - 1;
    while(sum[len] == 0 && len > 0)
        len--;
    for(int i = len; i >= 0; i--)
        putchar(sum[i] + '0');
    putchar('\n');
    return 0;
}