NTT板子
阿新 • 發佈:2019-03-02
operator complex ++ [] sizeof mem sub str using
不說別的。
這份NTT跑得比FFT快,不知道為什麽。
以下代碼針對兩個多項式次數 \(n, m \le 10^5\) 的數據範圍:乘積共有 \(n+m+1 \le 200001\) 項,補齊到 2 的冪後變換長度不超過 \(2^{18}\)。
#include<cstdio>
#include<vector>
#include<algorithm>
#include<cstring>
using namespace std;

// Reads one decimal integer (a '-' seen while skipping non-digits makes it
// negative) from stdin. Returns 0 if EOF is hit before any digit; previously
// the skip loop never terminated on EOF because EOF compares below '0'.
inline int read()
{
    int a = 0, c = getchar(), w = 1;
    for(; c < '0' || c > '9'; c = getchar())
    {
        if(c == EOF) return 0;
        if(c == '-') w = -1;
    }
    for(; c >= '0' && c <= '9'; c = getchar()) a = a * 10 + c - '0';
    return a * w;
}

// NTT-friendly prime 998244353 = 119 * 2^23 + 1 with primitive root 3.
const int md = 998244353, gmd = 3;

// Modular helpers; add/Add/sub assume operands already lie in [0, md).
inline int add(int x, int y) { x += y; return x >= md ? x - md : x; }
inline void Add(int& x, int y) { x += y; if(x >= md) x -= md; }
inline int sub(int x, int y) { x -= y; return x < 0 ? x + md : x; }
inline int mul(int x, int y) { return (long long)x*y%md; }
// Reduces any int (including negatives) into [0, md).
inline int norm(int x) { x %= md; return x < 0 ? x + md : x; }

// Binary exponentiation a^x mod md.
inline int qpow(int a, int x)
{
    int ret = 1;
    while(x)
    {
        if(x&1) ret = mul(ret, a);
        a = mul(a, a);
        x >>= 1;
    }
    return ret;
}
// Modular inverse via Fermat's little theorem (md is prime).
inline int inv(int x) { return qpow(x, md-2); }

const int maxn = 1<<17;       // each input polynomial has at most maxn coefficients
const int maxlen = maxn << 1; // largest transform length (2^18)

// w[1][h+k] = (forward (2h)-th root)^k and w[0][h+k] = its inverse power,
// for every half-length h = 1, 2, 4, ..., maxlen/2 and 0 <= k < h.
// invn[len] = len^{-1} mod md for power-of-two len <= maxlen. It needs
// maxlen + 1 slots: the old size 1<<18 made invn[1<<18] out of bounds,
// which was both written here and read by ntt at the full 10^5 data range.
int w[2][maxlen], invn[maxlen + 1];

// Fills the twiddle tables w and the inverse-length table invn.
void nttinit()
{
    for(int h = 1; h < maxlen; h <<= 1)
    {
        w[1][h] = w[0][h] = 1;
        int wn = qpow(gmd, (md-1)/(h<<1)), invwn = inv(wn);
        for(int j = h+1; j < (h<<1); j++)
        {
            w[1][j] = mul(w[1][j-1], wn);
            w[0][j] = mul(w[0][j-1], invwn);
        }
    }
    for(int i = 1; i <= maxlen; i <<= 1) invn[i] = inv(i);
}

// In-place iterative NTT of length n (power of two, n <= maxlen).
// typ = true: forward transform; typ = false: inverse transform, including
// the final multiplication by n^{-1}.
void ntt(int a[], int n, bool typ)
{
    // Bit-reversal permutation; j tracks the bit-reversed index of i.
    for(int i = 1, j = n>>1; i < n; i++)
    {
        if(i < j) swap(a[i], a[j]);
        for(int k = n>>1; (j^=k) < k; k >>= 1);
    }
    // Butterflies: i is the half-length of the blocks being merged.
    for(int i = 1; i < n; i <<= 1)
        for(int j = 0; j < n; j += i<<1)
            for(int k = 0; k < i; k++)
            {
                int u = a[j+k], v = mul(w[typ][i+k], a[j+i+k]);
                a[j+k] = add(u, v);
                a[j+i+k] = sub(u, v);
            }
    if(!typ) for(int i = 0; i < n; i++) a[i] = mul(a[i], invn[n]);
}

int tmp[maxlen];

// Multiplies a (an coefficients) by b (bn coefficients) mod md; the product's
// an+bn-1 coefficients overwrite a. Requires an+bn-1 <= maxlen. On the NTT path
// a and b must be zero past an/bn up to the padded power-of-two length, and b
// is left holding its forward transform.
void Mul(int a[], int an, int b[], int bn)
{
    // Schoolbook multiplication is cheaper when one side is short.
    if(an <= 48 || bn <= 48)
    {
        memset(tmp, 0, (an+bn-1)*sizeof(int));
        for(int i = 0; i < an; i++)
            for(int j = 0; j < bn; j++)
                Add(tmp[i+j], mul(a[i], b[j]));
        memcpy(a, tmp, (an+bn-1)*sizeof(int));
        return;
    }
    int n = 1;
    while(n < an+bn-1) n <<= 1;
    ntt(a, n, 1);
    ntt(b, n, 1);
    for(int i = 0; i < n; i++) a[i] = mul(a[i], b[i]);
    ntt(a, n, 0);
}

int n, m;
int a[maxlen], b[maxlen];

// Reads degrees n, m and the n+1 and m+1 coefficients, then prints the
// product's n+m+1 coefficients mod md, space-separated.
int main()
{
    n = read(); m = read();
    // Reject degrees that would overflow the fixed-size arrays.
    if(n < 0 || m < 0 || n + m + 1 > maxlen)
    {
        fprintf(stderr, "degree out of range\n");
        return 1;
    }
    nttinit();
    // Reduce inputs into [0, md) so negative coefficients stay valid for add/sub/mul.
    for(int i = 0; i < n+1; i++) a[i] = norm(read());
    for(int i = 0; i < m+1; i++) b[i] = norm(read());
    Mul(a, n+1, b, m+1);
    for(int i = 0; i < n+m+1; i++) printf("%d ", a[i]);
    printf("\n");
    return 0;
}
上面的 NTT 版本用時 1640ms。
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
using namespace std;

// Reads one decimal integer (a '-' seen while skipping non-digits makes it
// negative) from stdin. Returns 0 if EOF is hit before any digit; previously
// the skip loop never terminated on EOF because EOF compares below '0'.
inline int read()
{
    int a = 0, c = getchar(), w = 1;
    for(; c < '0' || c > '9'; c = getchar())
    {
        if(c == EOF) return 0;
        if(c == '-') w = -1;
    }
    for(; c >= '0' && c <= '9'; c = getchar()) a = a * 10 + c - '0';
    return a * w;
}

const double pi = 3.14159265358979323846264338327950288;

// Minimal complex number (r = real part, v = imaginary part). Named Complex
// rather than complex so it cannot clash with std::complex under
// `using namespace std;` if <complex> is ever included.
struct Complex
{
    double r, v;
    Complex() {}
    Complex(double rr, double vv) { r = rr; v = vv; }
};
inline Complex operator + (Complex a, Complex b) { return Complex(a.r+b.r, a.v+b.v); }
inline Complex operator - (Complex a, Complex b) { return Complex(a.r-b.r, a.v-b.v); }
inline Complex operator - (Complex x) { return Complex(-x.r, -x.v); }
inline Complex operator * (Complex a, Complex b) { return Complex(a.r*b.r-a.v*b.v, a.r*b.v+a.v*b.r); }

const int maxn = 1<<17; // each input polynomial has at most maxn coefficients

// In-place iterative FFT of length n (power of two). typ = true: forward
// transform; typ = false: inverse transform. Only the real part is divided by
// n on the inverse, since callers read only .r afterwards.
void fft(Complex a[], int n, bool typ)
{
    // Bit-reversal permutation; j tracks the bit-reversed index of i.
    for(int i = 1, j = n>>1; i < n; i++)
    {
        if(i < j) swap(a[i], a[j]);
        for(int k = n>>1; (j^=k) < k; k >>= 1);
    }
    for(int i = 1; i < n; i <<= 1)
    {
        // Primitive (2i)-th root of unity (sign of the imaginary part picks the
        // direction). It depends only on i, so compute it once per level
        // instead of once per block j.
        Complex wn = Complex(cos(pi/(double)i), (typ?1:-1)*sin(pi/(double)i));
        for(int j = 0; j < n; j += i<<1)
        {
            Complex w = Complex(1, 0);
            for(int k = 0; k < i; k++)
            {
                Complex u = a[j+k], v = w * a[j+i+k];
                a[j+k] = u + v;
                a[j+i+k] = u - v;
                w = w * wn;
            }
        }
    }
    if(!typ) for(int i = 0; i < n; i++) a[i].r /= n;
}

Complex tmp[maxn<<1];

// Multiplies a (an coefficients) by b (bn coefficients); the product's
// an+bn-1 coefficients overwrite a. Requires an+bn-1 <= maxn<<1. On the FFT
// path a and b must be zero past an/bn up to the padded power-of-two length,
// and b is left holding its forward transform.
void Mul(Complex a[], int an, Complex b[], int bn)
{
    // Schoolbook multiplication is cheaper when one side is short.
    if(an <= 48 || bn <= 48)
    {
        for(int i = 0; i < an+bn-1; i++) tmp[i] = Complex(0, 0);
        for(int i = 0; i < an; i++)
            for(int j = 0; j < bn; j++)
                tmp[i+j] = tmp[i+j] + a[i] * b[j];
        for(int i = 0; i < an+bn-1; i++) a[i] = tmp[i];
        return;
    }
    int n = 1;
    while(n < an+bn-1) n <<= 1;
    fft(a, n, 1);
    fft(b, n, 1);
    for(int i = 0; i < n; i++) a[i] = a[i] * b[i];
    fft(a, n, 0);
}

int n, m;
Complex a[maxn<<1], b[maxn<<1];

// Reads degrees n, m and the n+1 and m+1 integer coefficients, then prints
// the product's n+m+1 coefficients rounded to the nearest integer.
int main()
{
    n = read(); m = read();
    // Reject degrees that would overflow the fixed-size arrays.
    if(n < 0 || m < 0 || n + m + 1 > (maxn<<1))
    {
        fprintf(stderr, "degree out of range\n");
        return 1;
    }
    for(int i = 0; i < n+1; i++) a[i] = Complex(read(), 0);
    for(int i = 0; i < m+1; i++) b[i] = Complex(read(), 0);
    Mul(a, n+1, b, m+1);
    // llround rounds to nearest for negative results too; the old
    // int(x + 0.5) truncated toward zero, so e.g. -2.9999 printed as -2.
    for(int i = 0; i < n+m+1; i++) printf("%lld ", llround(a[i].r));
    printf("\n");
    return 0;
}
上面的 FFT 版本用時 1919ms。
NTT板子