bzoj 3625 小朋友和二叉樹 多項式開根
阿新 • 發佈：2018-12-24
常數大到飛起。
O(n log n) 的演算法在 CF 上跑了 2000 ms，也是神奇。
有空再看看怎麼把常數寫小一點。
NTT 做了個小優化，快了一點。
// BZOJ 3625 "The Child and Binary Tree".
// For s = 1..m, prints the coefficient of x^s in
//     F(x) = 2 / (1 + sqrt(1 - 4 C(x)))   (mod 998244353),
// where C(x) is the sum of x^v over the set of input weights v (1 <= v <= m).
// F is the power-series root with F(0) = 1 of F = 1 + C * F^2.
// Both the square root and the inverse use Newton iteration on top of an NTT.
#include <iostream>
#include <cstdio>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <utility>
#include <queue>
#include <set>
#include <ctime>
#include <cstdlib>
using namespace std;

const int N = 600020;      // capacity of every work array; must be >= 2 * len
const int P = 998244353;   // NTT-friendly prime: P - 1 = 119 * 2^23
const int pRoot = 3;       // primitive root modulo P
const int MAX_LOG = 21;    // wn[] covers transform sizes 2^1 .. 2^MAX_LOG

// Returns x^k mod p by square-and-multiply (k >= 0).
int qpow(int x, int k, int p) {
    int ret = 1;
    while (k) {
        if (k & 1) ret = 1LL * ret * x % p;
        k >>= 1;
        x = 1LL * x * x % p;
    }
    return ret;
}

int wn[MAX_LOG + 1];            // wn[i]: primitive 2^i-th root of unity mod P
int t1[N], t2[N], t3[N];        // scratch: t1/t2 used by getSqrt, t3 by getInv
const int inv2 = (P + 1) / 2;   // modular inverse of 2

// Fills wn[1..MAX_LOG]; must run before the first FFT call.
void getWn() {
    for (int i = 1; i <= MAX_LOG; ++i)
        wn[i] = qpow(pRoot, (P - 1) / (1 << i), P);
}

int rev[N], wnPow[N];

// Applies the bit-reversal permutation to y[0..len) (len a power of two).
void change(int y[], int len) {
    rev[0] = 0;
    for (int i = 1; i < len; ++i) {
        rev[i] = (rev[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
        if (i < rev[i]) swap(y[i], y[rev[i]]);
    }
}

// In-place number-theoretic transform of y[0..len), len a power of two with
// len <= 2^MAX_LOG and len <= N.  Entries must already lie in [0, P).
// on == 1: forward transform; on == -1: inverse transform, including the
// division by len.
void FFT(int y[], int len, int on) {
    change(y, len);
    for (int h = 2, id = 1; h <= len; h <<= 1, ++id) {
        const int half = h >> 1;
        // Twiddle factors wn[id]^0 .. wn[id]^(half-1) for this stage.
        wnPow[0] = 1;
        for (int j = 1; j < half; ++j) wnPow[j] = 1LL * wnPow[j - 1] * wn[id] % P;
        for (int j = 0; j < len; j += h) {
            for (int k = j; k < j + half; ++k) {
                const int u = y[k];
                const int t = 1LL * wnPow[k - j] * y[k + half] % P;
                y[k] = (u + t >= P) ? u + t - P : u + t;
                y[k + half] = (u - t < 0) ? u - t + P : u - t;
            }
        }
    }
    if (on == -1) {
        // The inverse transform equals the forward one with outputs 1..len-1
        // reversed, then scaled by 1/len.
        reverse(y + 1, y + len);
        const int inv = qpow(len, P - 2, P);
        for (int i = 0; i < len; ++i) y[i] = 1LL * y[i] * inv % P;
    }
}

// Power-series inverse: on return A0[0..k) holds B with A * B == 1 (mod x^k).
// Requires A[0] != 0 (mod P) and k a power of two.  Only A[0..k) is read.
// A0 must have room for 2k entries, and A0[k..2k) is left holding scratch
// values.  Clobbers t3.
// Newton step: B = B0 * (2 - A * B0) mod x^k, where B0 is the inverse mod x^(k/2).
void getInv(int A[], int A0[], int k) {
    if (k == 1) {
        A0[0] = qpow(A[0], P - 2, P);
        return;
    }
    getInv(A, A0, k / 2);
    const int n2 = 2 * k;   // large enough that B0 * A * B0 does not wrap around
    for (int i = 0; i < n2; ++i) t3[i] = (i < k) ? A[i] : 0;
    for (int i = k / 2; i < n2; ++i) A0[i] = 0;
    FFT(t3, n2, 1);
    FFT(A0, n2, 1);
    for (int i = 0; i < n2; ++i) {
        long long f = 2 - 1LL * t3[i] * A0[i] % P;   // pointwise (2 - A * B0)
        if (f < 0) f += P;
        A0[i] = f * A0[i] % P;
    }
    FFT(A0, n2, -1);
}

// Power-series square root: on return A0[0..k) holds B with B^2 == A (mod x^k)
// and B[0] == 1.  Requires A[0] == 1, because the seed B[0] = 1 is hard-coded,
// and k a power of two.  Only A[0..k) is read.  A0 must have room for 2k
// entries, and A0[k..2k) is left holding scratch values.  Clobbers t1, t2 and,
// through getInv, t3.
// Newton step: B = (B0 + A / B0) / 2 mod x^k, where B0 is the root mod x^(k/2).
void getSqrt(int A[], int A0[], int k) {
    if (k == 1) {
        A0[0] = 1;
        return;
    }
    getSqrt(A, A0, k / 2);
    const int n2 = 2 * k;
    for (int i = k / 2; i < n2; ++i) A0[i] = 0;
    getInv(A0, t1, k);                         // t1 = 1 / B0 mod x^k
    for (int i = k; i < n2; ++i) t1[i] = 0;
    for (int i = 0; i < n2; ++i) t2[i] = (i < k) ? A[i] : 0;
    FFT(A0, n2, 1);
    FFT(t1, n2, 1);
    FFT(t2, n2, 1);
    for (int i = 0; i < n2; ++i)
        A0[i] = (A0[i] + 1LL * t1[i] * t2[i]) % P;   // pointwise B0 + A / B0
    FFT(A0, n2, -1);
    for (int i = 0; i < k; ++i) A0[i] = 1LL * A0[i] * inv2 % P;
}

int n, m, c[N], d[N];

int main() {
    getWn();
    if (scanf("%d%d", &n, &m) != 2) {
        fprintf(stderr, "invalid input\n");
        return 1;
    }
    for (int i = 0; i < n; ++i) {
        int v;
        if (scanf("%d", &v) != 1) {   // never read an uninitialized weight
            fprintf(stderr, "invalid input\n");
            return 1;
        }
        // Weights above m cannot affect sizes 1..m.  Negative weights would
        // index out of bounds, and a weight of 0 would break the sqrt seed,
        // which needs 1 - 4C(0) == 1.
        if (v >= 1 && v <= m) c[v] = 1;
    }

    // len: the smallest power of two > m.  Every transform uses 2 * len
    // entries, so that must fit in the N-sized work arrays.
    int len = 1;
    while (len <= m && len <= N) len <<= 1;
    if (len <= m || 2 * len > N) {
        fprintf(stderr, "m is too large\n");
        return 1;
    }

    // c <- 1 - 4C (mod P), keeping len terms.
    for (int i = 0; i < len; ++i) c[i] = c[i] ? P - 4 : 0;
    c[0] += 1;                   // C(0) == 0 here, so c[0] becomes exactly 1

    getSqrt(c, d, len);          // d = sqrt(1 - 4C) mod x^len, d[0] == 1
    d[0] += 1;                   // d = 1 + sqrt(1 - 4C), d[0] == 2
    getInv(d, c, len);           // c = 1 / (1 + sqrt(1 - 4C)) mod x^len

    for (int i = 1; i <= m; ++i) {
        int ans = c[i] * 2;      // F = 2 * c; c[i] < P, so no int overflow
        if (ans >= P) ans -= P;
        printf("%d\n", ans);
    }
    return 0;
}