1. 程式人生 > 實用技巧 >多項式基礎:FFT與NTT

多項式基礎:FFT與NTT

本文為基礎部分。

多項式進階:多項式的高階運算

相似演算法:快速沃爾什變換(FWT)

FFT與NTT用來處理多項式乘法。

快速傅立葉變換(FFT)

小學生都能看懂的FFT!!!

\(Code\):

struct Complex {
	double x, y;
	Complex(double xx = 0, double yy = 0) {x = xx, y = yy;}
	Complex operator + (const Complex &i) const {
		return Complex(x + i.x, y + i.y);
	}
	Complex operator - (const Complex &i) const {
		return Complex(x - i.x, y - i.y);
	}
	Complex operator * (const Complex &a) const {
		return Complex(x * a.x - y * a.y, x * a.y + y * a.x);
	}
}A[N], B[N];
int n, m, limi = 1, l;
int r[N];
const double Pi = 3.14159265358979323846264;
void fft(Complex *a, int type) {
	for (register int i = 0; i < limi; ++i)
		if (i < r[i])	swap(a[i], a[r[i]]);
	for (register int j = 1; j < limi; j <<= 1) {//長度 
		Complex T(cos(Pi/j), type * sin(Pi / j));
		for (register int k = 0; k < limi; k += (j << 1)) {//第幾塊 
			Complex t(1, 0);
			for (register int p = 0; p < j; ++p, t = t * T) {//該塊的第幾個 
				Complex nx = a[k + p], ny = t * a[k + j + p];
				a[k + p] = nx + ny;
				a[k + j + p] = nx - ny;
			}
		}
	}
}

int main() {
	read(n); read(m);
	int aa;
	for (register int i = 0; i <= n; ++i)	read(aa), A[i].x = aa;
	for (register int i = 0; i <= m; ++i)	read(aa), B[i].x = aa;
	while (limi<=n + m)	limi <<= 1, l++;
	for (register int i = 0; i < limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	fft(A, 1); fft(B, 1);
	for (register int i = 0; i <= limi; ++i)	A[i] = A[i] * B[i];
	fft(A, -1);
	for (register int i = 0; i <= n + m; ++i)
		printf("%d ", (int)(A[i].x / limi + 0.5));
	return 0;
}

快速數論變換(NTT)

快速數論變換(NTT)小結

NTT(快速數論變換)用到的各種素數及原根

素數 原根
998244353 3
3221225473(long long) 5
395 824 185 999 37 (3e13) 5

記得取模!!

\(2020.7.28\) \(Update:\)更新了程式碼

\(Code:\)

const int P = 998244353;
const int G = 3;
const int Gi = (P + 1) / G;
inline void ntt(ll *a, int type) {
	for (register int i = 1; i < limi; ++i)
		if (i < r[i])	swap(a[i], a[r[i]]);
	for (register int i = 1; i < limi; i <<= 1) {//i < limi
		ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));//Attention!!
		for (register int j = 0; j < limi; j += (i << 1)) {
			ll t = 1;
			for (register int k = 0; k < i; ++k, t = t * T % P) {//Attention!! : % P
				ll nx = a[j + k], ny = a[j + k + i] * t % P;
				a[j + k] = (nx + ny) % P;
				a[j + k + i] = (nx - ny + P) % P;
			}
		}
	}
	if (type == -1) {
		ll inv = quickpow(limi, P - 2);
		for (register int i = 0; i < limi; ++i)
			a[i] = a[i] * inv % P;
	}
}
inline void mul(ll *a, ll *b, int n, int m) {//傳入 a, b,匯出到 a
	while (limi <= (n + m))	limi <<= 1, ++L;
	for (register int i = 1; i < limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));
	ntt(a, 1), ntt(b, 1);
	for (register int i = 0; i < limi; ++i)	a[i] = a[i] * b[i] % P;
	ntt(a, -1);
}
  • FFT與NTT(多項式乘法)的應用:

【模板】A*B Problem升級版(FFT快速傅立葉)

通過模擬乘法豎式,我們發現,高精乘其實就是在進行多項式乘法。這樣的話我們可以用FFT或NTT來把它優化到nlogn。

\(Code:\)

#define P 998244353
#define G 3
#define Gi 332748118
char as[N], bs[N];
int n, m;
ll A[N], B[N], ans[N];
ll limi = 1, l, inv;
int r[N];
inline ll quickpow(ll x, ll k)...
inline void ntt(ll *a, int type) {
	for (register int i = 0; i <= limi; ++i) 
		if (i < r[i])	swap(a[i], a[r[i]]);
	for (register int i = 1; i < limi; i <<= 1) {
		ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));
		for (register int j = 0; j < limi; j += (i << 1)) {
			ll t = 1;
			for (register int p = 0; p < i; ++p, t = t * T % P) {
				ll nx = a[j + p], ny = t * a[j + p + i] % P;
				a[j + p] = (nx + ny) % P;
				a[j + p + i] = (nx - ny + P) % P;
			}
		}
	}
}
int main() {
	scanf("%s%s", as, bs);
	n = strlen(as) - 1;
	m = strlen(bs) - 1;
	ll ct = 0;
	for (register int i = n; i >= 0; --i) A[ct++] = as[i] - '0';
	ct = 0;
	for (register int i = m; i >= 0; --i)	B[ct++] = bs[i] - '0';
	while (limi <= n + m)	limi <<= 1, l++;
	for (register int i = 1; i <= limi; ++i) 
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	ntt(A, 1); ntt(B, 1);
	for (register int i = 0; i <= limi; ++i)	A[i] = A[i] * B[i] % P;
	ntt(A, -1);
	inv = quickpow(limi, P - 2);
	for (register int i = 0; i <= limi; ++i)
		ans[i] = A[i] * inv % P;
	limi += 5;
	for (register int i = 0; i <= limi; ++i)
		if (ans[i] >= 10) {
			ans[i + 1] += ans[i] / 10;
			ans[i] %= 10;
		}
	ll len = 1;
	for (register int i = limi; i >= 0; --i)
		if (ans[i]) break;
		else	len = i - 1;
	for (register int i = len; i >= 0; --i) {
		printf("%lld", ans[i]);
	}
	return 0;
}

例題

通過數學推導,我們發現,要解決其中的旋轉求最大的aibi的和的問題時,我們可以把它轉化成求卷積(多項式乘法)後的後n項的最值問題,這裡用NTT優化。但其實這道題主要還是難在數學推導的想法以及如何想到卷積。

\(Code:\)

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
#include <cmath>
#include <string>
#define N 300010
#define P 998244353
#define G 3
#define Gi 332748118
#define inf 992337203685477580ll
typedef long long ll;
template<typename T> inline void read(T &x) {
	x = 0; char c = getchar(); bool flag = false;
	while (!isdigit(c)) {if (c == '-') flag = true; c = getchar(); }
	while (isdigit(c)) {x = (x << 1) + (x << 3) + (c ^ 48); c = getchar(); }
	if (flag)	x = -x;
}
using namespace std;
ll n, m, limi = 1, l; 
ll x[N], y[N], r[N];
ll ans, sum, toans = inf;
inline ll quickpow(ll x, ll k) {
	ll res = 1;
	while (k) {
		if (k & 1)	res = res * x % P;
		x = x * x % P;
		k >>= 1;
	}
	return res;
}
inline void ntt(ll *a, int type) {
	for (register int i = 0; i <= limi; ++i)
		if (i < r[i])	swap(a[i], a[r[i]]);
	for (register int i = 1; i < limi; i <<= 1) {
		ll T = quickpow(type == 1 ? G : Gi, (P - 1) / (i << 1));
		for (register int j = 0; j < limi; j += (i << 1)) {
			ll t = 1;
			for (register int p = 0; p < i; ++p, t = t * T % P) {
				ll nx = a[j + p], ny = t * a[j + p + i] % P;
				a[j + p] = (nx + ny) % P;
				a[j + p + i] = (nx - ny + P) % P;
			}
		}
	}
	if (type == -1) {
		ll inv = quickpow(limi, P - 2);
		for (register int i = 0; i <= limi; ++i)
			a[i] = a[i] * inv % P;
	}
}
int main() {
	read(n); read(m);
	for (register int i = 1; i <= n; ++i) read(x[i]), x[i + n] = x[i];
	for (register int i = 1; i <= n; ++i)	read(y[i]);
	for (register int i = 1; i <= n; ++i) {
		ans += x[i] * x[i] + y[i] * y[i];
		sum += x[i] - y[i];
	}
	sum *= 2;
	for (register int i = -m; i <= m; ++i) {
		toans = min(toans, 1ll * n * i * i + sum * i);
	}
	ans += toans;
	
	reverse(y + 1, y + n + 1);
	while (limi <= 2 * n)	limi <<= 1, l++;
	for (register int i = 0; i < limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	ntt(x, 1); ntt(y, 1);
	for (register int i = 0; i < limi; ++i)	x[i] = x[i] * y[i] % P;
	ntt(x, -1);
	sum = 0;
	for (register int i = n + 1; i <= (n << 1); ++i)	sum = max(sum, x[i]);
	ans -= 2 * sum;
	printf("%lld\n", ans);
	return 0;
}

  • 注意:
  1. 記得取模!+1

  2. 左移和右移一定分清!!

  3. 關於i = 0還是i = 1:

FFT和NTT裡都是i = 0,別寫成i = 1。

  1. 關於<= limi還是< limi:

寫<= limi總不會錯的。

統計答案的時候不要寫<= limi!!!

第一層迴圈也不要寫 <= limi,寫 < limi

  1. 到了後面(多項式乘法時)n和m的出現次數就少了,主要是limi。

  2. cosnt int Gi = (M + 1) / G;以後就這麼寫吧,省著把332748118 寫成 322748118

  3. NTT和FFT的第三層迴圈中的p應寫成(int p = 0; p < i; ++p, t = t × T % P)。 +1

  4. 記住,是ax = a[j + p], ay = t × a[i + j + p]!!!別忘了乘t!!

  5. NTT和FFT的第一層迴圈應寫成(int i = 1; i < limi; i <<= 1)。

  6. FFT中T為Complex(cos(PI / i), sin(PI / i) * type),橫座標是cos,縱座標是sin!!

  7. 一開始蝴蝶變換的時候是swap(a[i], a[r[i]]),不是swap(i, r[i])!! +1

習題

實際上這道題應該是例題的基礎,是純的FFT。

NTT配合manacher來做。細節不少,有一定難度。