1. 程式人生 > >FFT&NTT小結

FFT&NTT小結

我們 stat 矢量 splay pair 傳送門 mman fin problem

Preface

最近幾天學了一下FFT和NTT,感覺這東西理解了之後也沒有那麼難(其實我IDFT還不會證明)

我本來是準備寫一篇特別詳細的總結,結果發現了一篇和我想寫的內容相近的博客 傳送門

以及一篇只需初中數學知識的零基礎學習筆記 傳送門

所以我就只講一下大致的算法過程,具體可以去看一下鏈接的兩篇博客

先了解復數相關的性質和運算以及單位圓後,食用效果更佳

Problem

題目藍鏈

給你兩個多項式,求它們的卷積

Process

1.DFT

將給定的兩個多項式從系數式轉換為點值式

點值式的意思是在平面上找到\(n + 1\)個橫坐標不同的點來確定一個\(n\)次多項式,即\(F(x)\)對於若干個不同\(x\)的取值

直接暴力顯然是不行的,所以我們要想辦法優化

我們可以考慮把當前的問題轉化為子問題,然後再從子問題快速求解當前的問題

假設我們現在要求\(F(x) = \sum\limits_{i = 0}^{n - 1} a_i \cdot x^i\)的點值式,保證\(n = 2^k (k \in N)\)

我們可以把式子變一下形
\[ F(x) = (a_0 + a_2 \cdot x^2 + \cdots + a_{n - 2} \cdot x^{n - 2}) + (a_1 \cdot x + a_3 \cdot x^3 + \cdots + a_{n - 1} \cdot x^{n - 1}) \]


我們令
\[ \begin{aligned} G(x) &= a_0 + a_2 \cdot x + \cdots + a_{n - 2} \cdot x^{\frac{n}{2} - 1} \\ G'(x) &= a_1 + a_3 \cdot x + \cdots + a_{n - 1} \cdot x^{\frac{n}{2} - 1} \end{aligned} \]

則
\[ F(x) = G(x^2) + x \cdot G'(x^2) \]
但這樣好像還是沒有轉換為一模一樣的子問題,所以我們可以考慮代入一些具有某些特殊(奇奇怪怪的)性質的數值

在經過前人無數次嘗試之後,發現可以把\(\omega\)代入到式子裡去,這是因為\(\omega\)有一些比較神奇的性質

\(\omega\)的本質是一個復數,且滿足\(\omega^n = 1\),所以顯然\(\omega\)只能在單位圓上

於是我們記\(\omega_n^k (k \in [0, n))\)為單位根,其中\(\omega_n = e^{2 \pi i / n} = \cos \frac{2 \pi}{n} + i \sin \frac{2 \pi}{n}\);如果我們把這些單位根看成矢量,那麼它們便會\(n\)等分這個單位圓

它有這樣一些性質:(字母均為整數)

  • \(\omega_n^k = \omega_n^{k + a \cdot n}\)
  • \(\omega_n^{k_1} \cdot \omega_n^{k_2}= \omega_n^{k_1 + k_2}\)
  • \(\omega_{d \cdot n}^{d \cdot k} = \omega_n^k\)
  • \(\omega_n^{k + \frac{n}{2}} = - \omega_n^{k}\)

所以,(保證\(k < \frac{n}{2}\))
\[ \begin{aligned} F(\omega_n^k) &= G((\omega_n^k)^2) + \omega_n^k \cdot G'((\omega_n^k)^2) \\ &= G(\omega_n^{2k}) + \omega_n^k \cdot G'(\omega_n^{2k}) \\ &= G(\omega_{n / 2}^k) + \omega_n^k \cdot G'(\omega_{n / 2}^k) \end{aligned} \]
因為\(\omega_n^{k + n / 2} = - \omega_n^k\),且\((\omega_n^{k + n / 2})^2 = \omega_n^{2k + n} = \omega_n^{2k} = \omega_{n / 2}^k\),所以
\[ F(\omega_n^{k + n / 2}) = G(\omega_{n / 2}^k) - \omega_n^k \cdot G'(\omega_{n / 2}^k) \]
綜上,\(F\)函數的點值均可從函數\(G\)和\(G'\)的點值轉移過來,單層轉移的複雜度為\(\mathcal{O}(n)\)

分治後總時間復雜度\(\mathcal{O}(n \log n)\) (此\(n\)非彼\(n\))

2. 點值式相乘

我們直接把兩個點值式在對應位置上的值乘起來,即為它們卷積的點值式

3. IDFT

將答案的點值式轉換為係數式,經過一番矩陣的巧妙證明之後(假裝自己會)

只需要把DFT中用到的單位根\(\omega_n^k\)全部換成\(\omega_n^{-k}\),對點值式進行一次DFT,再把得到的每一項都除以\(n\),即為IDFT

至此,我們便得到了答案多項式的系數

Code

FFT
#include <bits/stdc++.h>

using namespace std;

// Shorthands for the members of a pair and for make_pair.
#define fst first
#define snd second
#define mp make_pair
// Square of x; the first operand is cast to LL (defined below) so the product is
// computed in 64 bits. Note that the argument expression is expanded twice.
#define squ(x) ((LL)(x) * (x))
// printf-style diagnostic output written to stderr.
#define debug(...) fprintf(stderr, __VA_ARGS__)

// Short aliases used by the rest of the file.
typedef long long LL;
typedef pair<int, int> pii;

// Raises `a` to `b` when `a < b`. Returns true iff `a` was changed.
template<typename T> inline bool chkmax(T &a, const T &b) {
    if (!(a < b)) return false;
    a = b;
    return true;
}
// Lowers `a` to `b` when `a > b`. Returns true iff `a` was changed.
template<typename T> inline bool chkmin(T &a, const T &b) {
    if (!(a > b)) return false;
    a = b;
    return true;
}

// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped; if any of the skipped
// characters is '-', the result is negated. Reading stops at the first non-digit
// after the number (that character is consumed).
//
// Returns the parsed value, or 0 when end of input is reached before any digit.
// Overflow of `int` is not detected.
inline int read() {
    int sum = 0, fg = 1;
    // Keep the character in an int: getchar() returns EOF out of band, and isdigit()
    // is only defined for EOF and values representable as unsigned char.
    int c = std::getchar();
    for (; c != EOF && !std::isdigit(c); c = std::getchar()) if (c == '-') fg = -1;
    for (; c != EOF && std::isdigit(c); c = std::getchar()) sum = sum * 10 + (c - '0');
    return fg * sum;
}

// Floating-point FFT used to multiply two polynomials with integer coefficients.
// Typical use: FFT::mul(a, n, b, m, c). init()/DFT() are exposed for callers that
// want to run the transforms themselves (call init() before DFT()).
namespace FFT {

    // Capacity of every table below; the transform length `len` must not exceed it.
    const int MAX_LEN = 1 << 21;
    const double PI = std::acos(-1.0);

    // Minimal complex number a + b*i with the three operations the butterfly needs.
    struct com {
        double a, b;
        com (double _a = 0.0, double _b = 0.0): a(_a), b(_b) { }
        com operator + (const com &t) const { return com(a + t.a, b + t.b); }
        com operator - (const com &t) const { return com(a - t.a, b - t.b); }
        com operator * (const com &t) const { return com(a * t.a - b * t.b, a * t.b + b * t.a); }
    };

    // len : current transform length (a power of two)
    // cnt : log2(len) - 1, i.e. the position of the highest bit of an index
    // rev : bit-reversal permutation of [0, len)
    // g   : g[i] = omega_len^i, the len-th roots of unity
    int len, cnt, rev[MAX_LEN];
    com g[MAX_LEN];

    // Prepares the tables for transforms that must hold N + 1 coefficients:
    // `len` becomes the smallest power of two strictly greater than N.
    void init(int N) {
        cnt = -1, len = 1;
        while (len <= N) {
            // Doubling once more would overrun the fixed-size tables.
            assert(len < MAX_LEN && "FFT::init: requested length exceeds MAX_LEN");
            len <<= 1, ++cnt;
        }
        // rev[0] is set explicitly so that the shift below is never evaluated with
        // cnt == -1 (the len == 1 case).
        rev[0] = 0;
        for (int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << cnt);
        // Each root is evaluated directly instead of multiplying the previous one by
        // omega_len, so rounding error does not accumulate along the table.
        for (int i = 0; i < len; i++) {
            const double angle = 2.0 * PI * i / len;
            g[i] = com(std::cos(angle), std::sin(angle));
        }
    }

    // In-place iterative transform of x[0 .. len).
    // op == -1 : inverse transform (conjugate roots, then every element is divided by len)
    // otherwise: forward transform (no scaling)
    void DFT(com *x, int op) {
        for (int i = 0; i < len; i++) if (i < rev[i]) std::swap(x[i], x[rev[i]]);
        for (int k = 2; k <= len; k <<= 1) {
            const int half = k / 2, step = len / k;
            for (int j = 0; j < len; j += k)
                for (int i = 0; i < half; i++) {
                    // omega_k^i = g[step * i]; its inverse is g[len - step * i] (and g[0] for i == 0).
                    const int idx = (op != -1 || i == 0) ? step * i : len - step * i;
                    const com X = x[j + i], Y = x[j + i + half] * g[idx];
                    x[j + i] = X + Y, x[j + i + half] = X - Y;
                }
        }
        // Normalise both components so the result is a true inverse for complex data too.
        if (op == -1) for (int i = 0; i < len; i++) x[i].a /= len, x[i].b /= len;
    }

    // Multiplies a (degree n, coefficients a[0..n]) by b (degree m, coefficients b[0..m])
    // and stores the n + m + 1 coefficients of the product in c[0..n+m], each rounded
    // to the nearest integer.
    void mul(int *a, int n, int *b, int m, int *c) {
        init(n + m);
        static com F[MAX_LEN], G[MAX_LEN];
        for (int i = 0; i < len; i++) F[i] = com(i <= n ? a[i] : 0.0, 0.0);
        for (int i = 0; i < len; i++) G[i] = com(i <= m ? b[i] : 0.0, 0.0);
        DFT(F, 1), DFT(G, 1);
        // Point-wise product, stored back into F so no third buffer is needed.
        for (int i = 0; i < len; i++) F[i] = F[i] * G[i];
        DFT(F, -1);
        for (int i = 0; i <= n + m; i++) c[i] = static_cast<int>(std::round(F[i].a));
    }

}

// Capacity of the coefficient arrays used by main().
const int maxn = 2e6 + 10;

// Reads the degrees n and m, then the n + 1 and m + 1 coefficients of the two
// polynomials, and prints the n + m + 1 coefficients of their product separated by
// single spaces and terminated by a newline.
// Returns 1 without reading further when the degrees do not fit the static buffers.
int main() {
#ifdef xunzhen
    // Local debugging only: redirect the standard streams to files.
    freopen("FFT.in", "r", stdin);
    freopen("FFT.out", "w", stdout);
#endif

    int n = read(), m = read();

    // a, b and c hold at most maxn entries each; the product needs n + m + 1 of them.
    if (n < 0 || m < 0 || n >= maxn || m >= maxn || n + m >= maxn) {
        fprintf(stderr, "polynomial degree out of range\n");
        return 1;
    }

    static int a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n; i++) a[i] = read();
    for (int i = 0; i <= m; i++) b[i] = read();

    FFT::mul(a, n, b, m, c);

    for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '\n');

    return 0;
}
NTT
#include <bits/stdc++.h>

using namespace std;

// Shorthands for the members of a pair and for make_pair.
#define fst first
#define snd second
#define mp make_pair
// Square of x; the first operand is cast to LL (defined below) so the product is
// computed in 64 bits. Note that the argument expression is expanded twice.
#define squ(x) ((LL)(x) * (x))
// printf-style diagnostic output written to stderr.
#define debug(...) fprintf(stderr, __VA_ARGS__)

// Short aliases used by the rest of the file.
typedef long long LL;
typedef pair<int, int> pii;

// Stores `b` into `a` if `b` is larger (per operator<); reports whether it did.
template<typename T> inline bool chkmax(T &a, const T &b) {
    const bool grows = a < b;
    if (grows) a = b;
    return grows;
}
// Stores `b` into `a` if `b` is smaller (per operator>); reports whether it did.
template<typename T> inline bool chkmin(T &a, const T &b) {
    const bool shrinks = a > b;
    if (shrinks) a = b;
    return shrinks;
}

// Reads the next decimal integer from stdin.
//
// Every non-digit character before the number is skipped; if any of the skipped
// characters is '-', the result is negated. Reading stops at the first non-digit
// after the number (that character is consumed).
//
// Returns the parsed value, or 0 when end of input is reached before any digit.
// Overflow of `int` is not detected.
inline int read() {
    int sum = 0, fg = 1;
    // Keep the character in an int: getchar() returns EOF out of band, and isdigit()
    // is only defined for EOF and values representable as unsigned char.
    int c = std::getchar();
    for (; c != EOF && !std::isdigit(c); c = std::getchar()) if (c == '-') fg = -1;
    for (; c != EOF && std::isdigit(c); c = std::getchar()) sum = sum * 10 + (c - '0');
    return fg * sum;
}

// Number-theoretic transform modulo 998244353, used to multiply two polynomials
// with integer coefficients; all results are residues in [0, mod).
// Typical use: NTT::mul(a, n, b, m, c). init()/DFT() are exposed for callers that
// want to run the transforms themselves (call init() before DFT()).
namespace NTT {

    // Capacity of every table below; the transform length `len` must not exceed it.
    const int MAX_LEN = 1 << 21;
    // Prime modulus and the base g0 whose powers serve as the roots of unity.
    const int mod = 998244353, g0 = 3;

    // len : current transform length (a power of two)
    // cnt : log2(len) - 1, i.e. the position of the highest bit of an index
    // rev : bit-reversal permutation of [0, len)
    // g   : g[i] = (g0^((mod - 1) / len))^i
    int len, cnt, rev[MAX_LEN], g[MAX_LEN];

    // Maps any int to its canonical residue in [0, mod).
    inline int canon(int x) { x %= mod; return x < 0 ? x + mod : x; }

    // Sum of x and y folded back into [0, mod); both operands must lie in (-mod, mod).
    inline int add(int x, int y) { return (x += y) < mod ? (x >= 0 ? x : x + mod) : x - mod; }
    // Product of x and y modulo mod, computed in 64 bits.
    inline int mul(int x, int y) { return static_cast<int>(static_cast<long long>(x) * y % mod); }
    // x^y modulo mod by binary exponentiation. A negative exponent is mapped to the
    // equivalent non-negative one modulo mod - 1, so Pow(x, -1) is the inverse of x.
    inline int Pow(int x, int y) {
        if (y < 0) y = static_cast<int>(-1LL * y * (mod - 2) % (mod - 1));
        int res = 1;
        for (; y; y >>= 1, x = mul(x, x)) if (y & 1) res = mul(res, x);
        return res;
    }

    // Prepares the tables for transforms that must hold N + 1 coefficients:
    // `len` becomes the smallest power of two strictly greater than N.
    void init(int N) {
        cnt = -1, len = 1;
        while (len <= N) {
            // Doubling once more would overrun the fixed-size tables.
            assert(len < MAX_LEN && "NTT::init: requested length exceeds MAX_LEN");
            len <<= 1, ++cnt;
        }
        // rev[0] is set explicitly so that the shift below is never evaluated with
        // cnt == -1 (the len == 1 case).
        rev[0] = 0;
        for (int i = 1; i < len; i++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << cnt);
        g[0] = 1;
        const int root = Pow(g0, (mod - 1) / len);
        for (int i = 1; i < len; i++) g[i] = mul(g[i - 1], root);
    }

    // In-place iterative transform of x[0 .. len); entries are expected in [0, mod).
    // op == -1 : inverse transform (inverse roots, then every element is multiplied by len^-1)
    // otherwise: forward transform (no scaling)
    void DFT(int *x, int op) {
        for (int i = 0; i < len; i++) if (i < rev[i]) std::swap(x[i], x[rev[i]]);
        for (int k = 2; k <= len; k <<= 1) {
            const int half = k / 2, step = len / k;
            for (int j = 0; j < len; j += k)
                for (int i = 0; i < half; i++) {
                    // The k-th root to the i is g[step * i]; its inverse is g[len - step * i] (g[0] for i == 0).
                    const int idx = (op != -1 || i == 0) ? step * i : len - step * i;
                    const int X = x[j + i], Y = mul(x[j + i + half], g[idx]);
                    x[j + i] = add(X, Y), x[j + i + half] = add(X, -Y);
                }
        }
        if (op == -1) {
            const int inv = Pow(len, -1);
            for (int i = 0; i < len; i++) x[i] = mul(x[i], inv);
        }
    }

    // Multiplies a (degree n, coefficients a[0..n]) by b (degree m, coefficients b[0..m])
    // modulo mod and stores the n + m + 1 coefficients of the product in c[0..n+m].
    // Input coefficients may be any int; they are reduced into [0, mod) first.
    void mul(int *a, int n, int *b, int m, int *c) {
        init(n + m);
        static int F[MAX_LEN], G[MAX_LEN];
        for (int i = 0; i < len; i++) F[i] = i <= n ? canon(a[i]) : 0;
        for (int i = 0; i < len; i++) G[i] = i <= m ? canon(b[i]) : 0;
        DFT(F, 1), DFT(G, 1);
        // Point-wise product, stored back into F so no third buffer is needed.
        for (int i = 0; i < len; i++) F[i] = mul(F[i], G[i]);
        DFT(F, -1);
        for (int i = 0; i <= n + m; i++) c[i] = F[i];
    }

}

// Capacity of the coefficient arrays used by main().
const int maxn = 2e6 + 10;

// Reads the degrees n and m, then the n + 1 and m + 1 coefficients of the two
// polynomials, and prints the n + m + 1 coefficients of their product (as computed
// by NTT::mul) separated by single spaces and terminated by a newline.
// Returns 1 without reading further when the degrees do not fit the static buffers.
int main() {
#ifdef xunzhen
    // Local debugging only: redirect the standard streams to files.
    freopen("NTT.in", "r", stdin);
    freopen("NTT.out", "w", stdout);
#endif

    int n = read(), m = read();

    // a, b and c hold at most maxn entries each; the product needs n + m + 1 of them.
    if (n < 0 || m < 0 || n >= maxn || m >= maxn || n + m >= maxn) {
        fprintf(stderr, "polynomial degree out of range\n");
        return 1;
    }

    static int a[maxn], b[maxn], c[maxn];
    for (int i = 0; i <= n; i++) a[i] = read();
    for (int i = 0; i <= m; i++) b[i] = read();

    NTT::mul(a, n, b, m, c);

    for (int i = 0; i <= n + m; i++) printf("%d%c", c[i], i < n + m ? ' ' : '\n');

    return 0;
}

Summary

其實NTT就是把FFT放在模\(p\)意義下進行:我們可以找一個模\(p\)的原根\(g\),用\(g^{(p - 1) / n}\)來代替\(\omega_n\)(要求\(n\)整除\(p - 1\))

NTT可以用來避免浮點數的緩慢運算(但好像取模運算更慢(霧)

IDFT就先留個坑,等以後再來填算了

FFT&NTT小結