「任意模數多項式乘法」
阿新 • • 發佈:2021-01-02
「任意模數多項式乘法」
前置知識
基本問題
給定一個 \(n\) 次多項式 \(F(x)\) 和一個 \(m\) 次多項式,求出
\[F(x)\times G(x) \]係數對 \(p\) 取模,且不保證 \(p\) 可以分解成 \(p=2^ka+1\) 之形式,\(0\leq a_i,b_i\leq 10^9\),\(2\leq p\leq 10^9+9\)
考慮直接用 \(FFT\),但是值域太大,\(long\; double\) 都炸了,精度也無法保證
直接用 \(NTT\),但是是任意模數,根本用不了
處理這種問題,我們常有兩種做法:
三模NTT
都 \(1202\) 年了,不會還有人寫三模 \(NTT\)
好吧,其實是我不會
另一種比較常用的做法是
MTT
既然 \(FFT\) 處理不了值域很大的情況,我們就從問題入手,將值域縮小
不妨將兩個多項式拆成:
\[F(x)=M\times A(x)+B(x) \]\[G(x)=M\times C(x)+D(x) \]當 \(M=2^{15}\) 時,可以完美避免炸 \(double\) 的問題
現在問題就轉化為了
\[(M\times A(x)+B(x))\times (M\times C(x)+D(x)) \]\[M^2A(x)C(x)+M(B(x)C(x)+A(x)D(x))+B(x)D(x) \]所以直接 \(7\) 次 \(FFT\)
\(PS\):可以省到 \(4\) 次 \(FFT\),但是我還不會
程式碼
#include <cstdio> #include <cstring> #include <iostream> #include <algorithm> #include <cmath> typedef long long ll; typedef unsigned long long ull; using namespace std; const int maxn = 3e5 + 50, INF = 0x3f3f3f3f; const double pi = acos (-1); inline int read () { register int x = 0, w = 1; register char ch = getchar (); for (; ch < '0' || ch > '9'; ch = getchar ()) if (ch == '-') w = -1; for (; ch >= '0' && ch <= '9'; ch = getchar ()) x = x * 10 + ch - '0'; return x * w; } inline void write (register int x) { if (x / 10) write (x / 10); putchar (x % 10 + '0'); } int n, m, M, mod, len = 1, bit; int rev[maxn], f[maxn]; struct Complex { double x, y; Complex () {} Complex (register double a, register double b) { x = a, y = b; } inline Complex operator + (const Complex &a) const { return Complex (x + a.x, y + a.y); } inline Complex operator - (const Complex &a) const { return Complex (x - a.x, y - a.y); } inline Complex operator * (const Complex &a) const { return Complex (x * a.y + y * a.x, y * a.y - x * a.x); } } g[maxn], a[maxn], b[maxn], c[maxn], d[maxn], omega[maxn]; inline void FFT (register int len, register Complex * a, register int opt) { for (register int i = 1; i < len; i ++) if (i < rev[i]) swap (a[i], a[rev[i]]); for (register int d = 1; d < len; d <<= 1) { for (register int i = 0; i < len; i += d << 1) { for (register int j = 0; j < d; j ++) { register Complex w = omega[len / (d << 1) * j]; w.x *= opt; register Complex x = a[i + j], y = w * a[i + j + d]; a[i + j] = x + y, a[i + j + d] = x - y; } } } } int main () { n = read(), m = read(), mod = read(), M = 1 << 15; while (len <= n + m) len <<= 1, bit ++; for (register int i = 0; i < len; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << bit - 1), omega[i] = Complex (sin (2 * pi * i / len), cos (2 * pi * i / len)); for (register int i = 0, x; i <= n; i ++) x = read(), a[i].y = x / M, b[i].y = x % M; for (register int i = 0, x; i <= m; i ++) x = read(), c[i].y = x / M, d[i].y = x % M; FFT (len, a, 1), FFT (len, b, 1), FFT (len, c, 1), FFT (len, d, 1); for (register int i = 0; i < len; i ++) g[i] = a[i] * c[i]; FFT (len, g, -1); for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod * M % mod * M % mod) % mod; for (register int i = 0; i < len; i ++) g[i] = a[i] * d[i] + b[i] * c[i]; FFT (len, g, -1); for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod * M % mod) % mod; for (register int i = 0; i < len; i ++) g[i] = b[i] * d[i]; FFT (len, g, -1); for (register int i = 0; i <= n + m; i ++) f[i] = (f[i] + (ll) (g[i].y / len + 0.5) % mod) % mod; for (register int i = 0; i <= n + m; i ++) printf ("%d ", f[i]); putchar ('\n'); return 0; }