1. 程式人生 > 實用技巧 > FFT | NTT

FFT | NTT

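FFT (floating point, over complex numbers). The program below reads $n$ and $q_1, \dots, q_n$ and prints, for every $j$, $E_j = \sum_{i<j} \frac{q_i}{(j-i)^2} - \sum_{i>j} \frac{q_i}{(i-j)^2}$. Both sums are computed as convolutions with the kernel $b_k = 1/k^2$: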

#include <bits/stdc++.h>
// Shared type aliases and constants, declared as typed entities instead of
// preprocessor macros so they obey scope and type checking.
using ll = long long;               // 64-bit integer alias (not referenced by the FFT code in this file)
const int inf = 0x3f3f3f3f;         // "infinity" sentinel (not referenced by the FFT code in this file)
using db = double;                  // floating-point type used for complex components
const double Pi = std::acos(-1.0);  // pi, computed once instead of at every use
const double eps = 1e-8;            // comparison tolerance (not referenced by the FFT code in this file)
const int N = 400005;               // capacity of the global coefficient buffers
using namespace std;
// Returns the square of x (x multiplied by itself), in the argument's own type.
template <class Value>
Value sqr(Value x) {
    const Value squared = x * x;
    return squared;
}
struct Complex {
    db x, y;
    Complex (db xx = 0, db yy = 0) { x = xx; y = yy; }
    Complex operator + (Complex B) { return Complex(x + B.x, y + B.y); }
    Complex operator - (Complex B) { return Complex(x - B.x, y - B.y); }
    Complex operator * (Complex B) { return Complex(x * B.x - y * B.y, x * B.y + y * B.x); }
}a[N], b[N], c[N], d[N];
// Bit-reversal permutation of 0 .. limit-1, filled in main before FFT runs.
int r[N];
// n: number of input values; m: not referenced in this file.
int n;
int m;
// limit: transform length (a power of two); l: log2(limit).
int limit;
int l;
// In-place iterative radix-2 FFT over a[0 .. limit-1].
//   a    : array holding at least `limit` complex values; overwritten in place.
//   type : +1 for the forward transform, -1 for the inverse transform.
// Relies on the globals `limit` (a power of two) and `r` (the bit-reversal
// permutation of 0 .. limit-1), both prepared by the caller.
// The inverse divides BOTH the real and the imaginary parts by `limit`, so
// FFT(a, 1) followed by FFT(a, -1) restores the original complex values.
void FFT(Complex *a, int type) {
    // Reorder into bit-reversed index order so butterflies can run in place.
    for (int i = 0; i < limit; i ++) if (i < r[i]) swap(a[i], a[r[i]]);
    std::vector<Complex> roots;
    for (int mid = 1; mid < limit; mid <<= 1) {
        // Twiddle factors exp(type * i * pi * k / mid) for k < mid. Each is
        // computed directly from cos/sin. Repeated multiplication by a single
        // step root would accumulate rounding error along the level.
        roots.resize(mid);
        for (int k = 0; k < mid; k ++)
            roots[k] = Complex(cos(Pi * k / mid), type * sin(Pi * k / mid));
        for (int R = mid << 1, j = 0; j < limit; j += R) {
            for (int k = 0; k < mid; k ++) {
                Complex x = a[j + k], y = roots[k] * a[j + mid + k];
                a[j + k] = x + y; a[j + mid + k] = x - y;
            }
        }
    }
    if (type == -1) {
        for (int i = 0; i < limit; i ++) {
            a[i].x /= limit;
            a[i].y /= limit;
        }
    }
    // For integer-valued results, callers can round with (long long)floor(a[i].x + 0.5).
}
int main() {
    scanf("%d",&n);
    for (int i = 1; i <= n; i ++) {
        db x; scanf("%lf", &x);
        a[i].x = c[n - i + 1].x = x; b[i].x = 1.0 / sqr(i * 1.0);
    }
    limit = 1; while (limit <= (n << 1)) limit <<= 1, l ++;
    for (int i = 0; i < limit; i ++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
    FFT(a, 1); FFT(b, 1);FFT(c, 1);
    for (int i = 0; i < limit; i ++) a[i] = a[i] * b[i];
    FFT(a, -1);
    for (int i = 0; i < limit; i ++) c[i] = c[i] * b[i];
    FFT(c, -1);
    for (int i = 1; i <= n; i ++) printf("%.3lf\n", a[i].x - c[n - i + 1].x);
    return 0;
}

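NTT (number-theoretic transform modulo the prime $P = 998244353$ with primitive root $G = 3$). The program below multiplies two polynomials of degrees $N$ and $M$. The `LL` type and the `read()` input helper it uses are not defined in this listing: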

// NTT parameters. P = 998244353 = 119 * 2^23 + 1 is prime, G = 3 is a
// primitive root modulo P, and Gi = 332748118 is its inverse (3 * Gi = 1 mod P).
// MAXN bounds the transform length `limit` and the coefficient buffers.
const int MAXN = 3000010, P = 998244353, G = 3, Gi = 332748118;
static_assert(1LL * G * Gi % P == 1, "Gi must be the modular inverse of G");
// N, M: input polynomial degrees; limit: transform length (a power of two);
// L: log2(limit); r: bit-reversal permutation of 0 .. limit-1.
int N, M, limit = 1, L, r[MAXN];
// Coefficients of the two input polynomials; a also receives the product.
LL a[MAXN], b[MAXN];
// Modular exponentiation by repeated squaring: returns a^k mod P for k >= 0
// (a non-positive k yields 1).
// The base is first reduced into [0, P), so every product below is < P^2 and
// fits in LL even when the caller passes a value outside that range.
// Assumes LL is a 64-bit signed type such as long long; it is not defined in this listing.
inline LL fastpow(LL a, LL k) {
	a %= P;
	if (a < 0) a += P;
	LL base = 1;
	for (; k > 0; k >>= 1) {
		if (k & 1) base = base * a % P;
		a = a * a % P;
	}
	return base % P;
}
// In-place iterative radix-2 number-theoretic transform over A[0 .. limit-1].
//   A    : values, each expected in [0, P); overwritten in place.
//   type : 1 for the forward transform (root G), -1 for the inverse (root Gi).
//          The inverse also multiplies every entry by limit^{-1} mod P.
// Relies on the globals `limit` (a power of two dividing P - 1, i.e. at most
// 2^23) and `r` (the bit-reversal permutation of 0 .. limit-1).
inline void NTT(LL *A, int type) {
	// Reorder into bit-reversed index order so butterflies can run in place.
	for(int i = 0; i < limit; i++)
		if(i < r[i]) swap(A[i], A[r[i]]);
	for(int mid = 1; mid < limit; mid <<= 1) {
		// Primitive (2 * mid)-th root of unity modulo P (inverse root for type == -1).
		LL Wn = fastpow(type == 1 ? G : Gi, (P - 1) / (mid << 1));
		for(int j = 0; j < limit; j += (mid << 1)) {
			LL w = 1;
			for(int k = 0; k < mid; k++, w = w * Wn % P) {
				// Butterfly operands stay in LL, matching the array element type.
				LL x = A[j + k], y = w * A[j + k + mid] % P;
				A[j + k] = (x + y) % P;
				A[j + k + mid] = (x - y + P) % P;
			}
		}
	}
	if (type == -1) {
		// limit^{-1} mod P via Fermat's little theorem (P is prime).
		LL inv = fastpow(limit, P - 2);
		for(int i = 0; i < limit; i++)
			A[i] = A[i] * inv % P;
	}
}
int main() {
	N = read(); M = read();
	for(int i = 0; i <= N; i++) a[i] = (read() + P) % P;
	for(int i = 0; i <= M; i++) b[i] = (read() + P) % P;
	while(limit <= N + M) limit <<= 1, L++;
	for(int i = 0; i < limit; i++) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (L - 1));	
	NTT(a, 1);NTT(b, 1);	
	for(int i = 0; i < limit; i++) a[i] = (a[i] * b[i]) % P;
	NTT(a, -1);	
	for(int i = 0; i <= N + M; i++)
		printf("%lld ", a[i]);
	return 0;
}