1. 程式人生 > 大型大常數多項式模板

大型大常數多項式模板

pac bit pre sin class amp struct std sta

# include <bits/stdc++.h>
using namespace std;
// 64-bit integer type used for intermediate modular products.
using ll = long long;

// NTT-friendly prime 998244353 = 119 * 2^23 + 1; FFT::Init uses 3 as its root.
constexpr int mod = 998244353;
// Modular inverse of 2 under `mod`, i.e. (mod + 1) / 2 = 499122177.
constexpr int inv2 = (mod + 1) / 2;
// Capacity of every global coefficient / work array: 2^18 entries.
constexpr int maxn = 1 << 18;

/*
const double pi(acos(-1));

struct Complex {
    double a, b;

    inline Complex() {
        a = b = 0;
    }

    inline Complex(double _a, double _b) {
        a = _a, b = _b;
    }

    inline Complex operator +(Complex x) const {
        return Complex(a + x.a, b + x.b);
    }

    inline Complex operator -(Complex x) const {
        return Complex(a - x.a, b - x.b);
    }

    inline Complex operator *(Complex x) const {
        return Complex(a * x.a - b * x.b, a * x.b + b * x.a);
    }

    inline Complex Conj() {
        return Complex(a, -b);
    }
};
*/

/// Modular exponentiation by repeated squaring.
/// @param x base; reduced into [0, mod) first so that x * x never overflows
///          a 64-bit integer and negative bases are handled.
/// @param y exponent; must be >= 0 (callers pass mod - 2 to get Fermat
///          inverses). A negative exponent is not supported and yields 1
///          instead of looping forever.
/// @return x^y mod `mod`, in [0, mod). Note Pow(0, mod - 2) == 0, so
///         "inverting" zero silently returns 0.
inline int Pow(ll x, int y) {
    x %= mod;
    if (x < 0) x += mod;
    ll ret = 1;
    for (; y > 0; y >>= 1, x = x * x % mod)
        if (y & 1) ret = ret * x % mod;
    return static_cast<int>(ret);
}

/// In-place modular addition: x becomes (x + y) mod `mod`.
/// Both operands are expected to already lie in [0, mod), so a single
/// conditional subtraction is enough to bring the sum back into range.
inline void Inc(int &x, const int y) {
    const int sum = x + y;
    x = sum < mod ? sum : sum - mod;
}

namespace FFT {

    // Disabled alternative: a double-precision FFT that splits coefficients
    // into 15-bit halves (a[i] & 32767, a[i] >> 15), presumably for moduli that
    // are not NTT-friendly. Kept for reference only. NOTE(review): its Init
    // still evaluates 1 << (l - 1) when len == 1.
    /* all module
    Complex ma[maxn], mb[maxn], w[maxn], a1[maxn], a2[maxn];
    int r[maxn], l, len, a[maxn], b[maxn];

    inline void DFT(Complex *p, int opt) {
        register int i, j, k, t;
        register Complex wn, x, y;
        for (i = 0; i < len; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
        for (i = 1; i < len; i <<= 1)
            for(t = i << 1, j = 0; j < len; j += t)
                for (k = 0; k < i; ++k) {
                    wn = w[len / i * k];
                    if (opt == -1) wn.b *= -1;
                    x = p[j + k], y = wn * p[i + j + k];
                    p[j + k] = x + y, p[i + j + k] = x - y;
                }
    }

    inline void Init(const int n) {
        register int i, x, y;
        for (l = 0, len = 1; len < n; len <<= 1) ++l;
        for (i = 0; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
        for (i = 0; i < len; ++i) a1[i] = a2[i] = ma[i] = mb[i] = Complex(0, 0), a[i] = b[i] = 0;
        for (i = 0; i < len; ++i) w[i] = Complex(cos(pi * i / len), sin(pi * i / len));
    }

    inline void Calc1() {
        register int i, k, v1, v2, v3;
        register Complex ca, cb, da1, da2, db1, db2;
        for (i = 0; i < len; ++i) ma[i] = Complex(a[i] & 32767, a[i] >> 15), mb[i] = Complex(b[i] & 32767, b[i] >> 15);
        for (DFT(ma, 1), DFT(mb, 1), i = 0; i < len; ++i) {
            k = (len - i) & (len - 1), ca = ma[k].Conj(), cb = mb[k].Conj();
            da1 = (ca + ma[i]) * Complex(0.5, 0), da2 = (ma[i] - ca) * Complex(0, -0.5);
            db1 = (cb + mb[i]) * Complex(0.5, 0), db2 = (mb[i] - cb) * Complex(0, -0.5);
            a1[i] = da1 * db1 + (da1 * db2 + da2 * db1) * Complex(0, 1), a2[i] = da2 * db2;
        }
        for (DFT(a1, -1), DFT(a2, -1), i = 0; i < len; ++i) {
            v1 = (ll)(a1[i].a / len + 0.5) % mod, v2 = (ll)(a1[i].b / len + 0.5) % mod;
            v3 = (ll)(a2[i].a / len + 0.5) % mod, a[i] = (((ll)v3 << 30) + ((ll)v2 << 15) + v1) % mod;
            if (a[i] < 0) a[i] += mod;
        }
    }

    inline void Calc2() {
        register int i, k, v1, v2, v3;
        register Complex ca, cb, da1, da2, db1, db2;
        for (i = 0; i < len; ++i) ma[i] = Complex(a[i] & 32767, a[i] >> 15), mb[i] = Complex(b[i] & 32767, b[i] >> 15);
        for (DFT(ma, 1), DFT(mb, 1), i = 0; i < len; ++i) {
            k = (len - i) & (len - 1), ca = ma[k].Conj(), cb = mb[k].Conj();
            da1 = (ca + ma[i]) * Complex(0.5, 0), da2 = (ma[i] - ca) * Complex(0, -0.5);
            db1 = (cb + mb[i]) * Complex(0.5, 0), db2 = (mb[i] - cb) * Complex(0, -0.5);
            a1[i] = da1 * db1 + (da1 * db2 + da2 * db1) * Complex(0, 1), a2[i] = da2 * db2;
        }
        for (DFT(a1, -1), DFT(a2, -1), i = 0; i < len; ++i) {
            v1 = (ll)(a1[i].a / len + 0.5) % mod, v2 = (ll)(a1[i].b / len + 0.5) % mod;
            v3 = (ll)(a2[i].a / len + 0.5) % mod, a[i] = (((ll)v3 << 30) + ((ll)v2 << 15) + v1) % mod;
            if (a[i] < 0) a[i] += mod;
        }
        for (i = 0; i < len; ++i) ma[i] = Complex(a[i] & 32767, a[i] >> 15), mb[i] = Complex(b[i] & 32767, b[i] >> 15);
        for (DFT(ma, 1), DFT(mb, 1), i = 0; i < len; ++i) {
            k = (len - i) & (len - 1), ca = ma[k].Conj(), cb = mb[k].Conj();
            da1 = (ca + ma[i]) * Complex(0.5, 0), da2 = (ma[i] - ca) * Complex(0, -0.5);
            db1 = (cb + mb[i]) * Complex(0.5, 0), db2 = (mb[i] - cb) * Complex(0, -0.5);
            a1[i] = da1 * db1 + (da1 * db2 + da2 * db1) * Complex(0, 1), a2[i] = da2 * db2;
        }
        for (DFT(a1, -1), DFT(a2, -1), i = 0; i < len; ++i) {
            v1 = (ll)(a1[i].a / len + 0.5) % mod, v2 = (ll)(a1[i].b / len + 0.5) % mod;
            v3 = (ll)(a2[i].a / len + 0.5) % mod, a[i] = (((ll)v3 << 30) + ((ll)v2 << 15) + v1) % mod;
            if (a[i] < 0) a[i] += mod;
        }
    }
    */

    // Shared NTT state:
    //   a, b : operand buffers; a also receives the result of Calc1/Calc2.
    //   len  : current transform size (a power of two); l = log2(len).
    //   r    : bit-reversal permutation for len.
    //   w[0] : powers of the forward len-th root of unity; w[1] : inverse roots.
    // NOTE(review): nothing checks that len stays <= maxn.
    int a[maxn], b[maxn], len, r[maxn], l, w[2][maxn];

    /// Prepares a transform of size len = smallest power of two >= n:
    /// builds the bit-reversal table, zeroes a[0..len) and b[0..len), and
    /// tabulates the forward/inverse roots of unity (3 generates the roots).
    inline void Init(const int n) {
        for (l = 0, len = 1; len < n; len <<= 1) ++l;
        // The high bit is written as (len >> 1) instead of 1 << (l - 1): the
        // latter shifts by -1 (undefined behaviour) when len == 1.
        r[0] = 0;
        for (int i = 1; i < len; ++i) r[i] = (r[i >> 1] >> 1) | ((i & 1) ? (len >> 1) : 0);
        for (int i = 0; i < len; ++i) a[i] = b[i] = 0;
        const int x = Pow(3, (mod - 1) / len), y = Pow(x, mod - 2);
        w[0][0] = w[1][0] = 1;
        for (int i = 1; i < len; ++i) {
            w[0][i] = (ll)w[0][i - 1] * x % mod;
            w[1][i] = (ll)w[1][i - 1] * y % mod;
        }
    }

    /// In-place iterative NTT of p[0..len) using the tables from Init.
    /// opt == 1 is the forward transform; opt == -1 is the inverse transform,
    /// including the final multiplication by 1 / len.
    inline void NTT(int *p, const int opt) {
        for (int i = 0; i < len; ++i) if (r[i] < i) swap(p[r[i]], p[i]);
        const int *root = w[opt == -1];
        for (int i = 1; i < len; i <<= 1) {
            const int t = i << 1, step = len / t;  // root[step * k]: primitive t-th root^k
            for (int j = 0; j < len; j += t)
                for (int k = 0; k < i; ++k) {
                    const int x = p[j + k];
                    const int y = (ll)root[step * k] * p[i + j + k] % mod;
                    p[j + k] = x + y >= mod ? x + y - mod : x + y;
                    p[i + j + k] = x - y < 0 ? x - y + mod : x - y;
                }
        }
        if (opt == -1) {
            const int inv_len = Pow(len, mod - 2);
            for (int i = 0; i < len; ++i) p[i] = (ll)p[i] * inv_len % mod;
        }
    }

    /// a <- a * b as a cyclic convolution of length len (b is transformed in place).
    inline void Calc1() {
        NTT(a, 1), NTT(b, 1);
        for (int i = 0; i < len; ++i) a[i] = (ll)a[i] * b[i] % mod;
        NTT(a, -1);
    }

    /// a <- a * b * b as a cyclic convolution of length len (b is transformed
    /// in place); this is the product needed by the Newton step of Inv.
    inline void Calc2() {
        NTT(a, 1), NTT(b, 1);
        for (int i = 0; i < len; ++i) a[i] = (ll)a[i] * b[i] % mod * b[i] % mod;
        NTT(a, -1);
    }
}

struct Poly {
    vector <int> v;

    inline Poly() {
        v.resize(1);
    }

    inline Poly(const int d) {
        v.resize(d);
    }

    inline int Length() const {
        return v.size();
    }
    
    inline void Adjust() {
        register int n = v.size(), len;
        for (len = 1; len < n; len <<= 1);
        v.resize(len);
    }

    inline Poly operator +(Poly b) const {
        register int i, l1 = Length(), l2 = b.Length(), l3 = max(l1, l2);
        register Poly c(l3);
        for (i = 0; i < l1; ++i) c.v[i] = v[i];
        for (i = 0; i < l2; ++i) Inc(c.v[i], b.v[i]);
        return c;
    }

    inline Poly operator -(Poly b) const {
        register int i, l1 = Length(), l2 = b.Length(), l3 = max(l1, l2);
        register Poly c(l3);
        for (i = 0; i < l1; ++i) c.v[i] = v[i];
        for (i = 0; i < l2; ++i) Inc(c.v[i], mod - b.v[i]);
        return c;
    }

    inline void InvMul(Poly b) {
        register int i, l1 = Length(), l2 = b.Length(), l3 = l1 + l2 - 1;
        FFT :: Init(l3);
        for (i = 0; i < l1; ++i) FFT :: a[i] = v[i];
        for (i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
        FFT :: Calc2();
    }

    inline Poly operator *(Poly b) const {
        register int i, l1 = Length(), l2 = b.Length(), l3 = l1 + l2 - 1;
        register Poly c(l3);
        FFT :: Init(l3);
        for (i = 0; i < l1; ++i) FFT :: a[i] = v[i];
        for (i = 0; i < l2; ++i) FFT :: b[i] = b.v[i];
        FFT :: Calc1();
        for (i = 0; i < l3; ++i) c.v[i] = FFT :: a[i];
        return c;
    }

    inline Poly operator *(int b) const {
        register int i, l = Length();
        register Poly c(l);
        for (i = 0; i < l; ++i) c.v[i] = (ll)v[i] * b % mod;
        return c;
    }

    inline int Calc(const int x) {
        register int i, ret = v[0], l = Length(), now = x;
        for (i = 1; i < l; ++i) Inc(ret, (ll)now * v[i] % mod), now = (ll)now * x % mod;
        return ret;
    }
};

/// Term-wise integration truncated to len coefficients:
/// q[k] = p[k - 1] / k for 1 <= k < len, and q[0] = 0 (zero constant term).
/// Requires p.v.size() >= len - 1 and q.v.size() >= len; p is taken by value,
/// so p and q may refer to the same polynomial.
inline void Calc(Poly p, Poly &q, int len) {
    for (int k = len - 1; k != 0; --k) {
        const int inv_k = Pow(k, mod - 2);
        q.v[k] = (ll)p.v[k - 1] * inv_k % mod;
    }
    q.v[0] = 0;
}

/// Term-wise derivative truncated to len coefficients:
/// q[k] = (k + 1) * p[k + 1] for 0 <= k < len - 1, and q[len - 1] = 0.
/// Requires p.v.size() >= len and q.v.size() >= len; p is taken by value, so
/// the write order does not matter even if p and q are the same polynomial.
inline void ICalc(Poly p, Poly &q, int len) {
    for (int k = 0; k + 1 < len; ++k) q.v[k] = (ll)p.v[k + 1] * (k + 1) % mod;
    q.v[len - 1] = 0;
}

/// Power-series inverse: q = 1 / p mod x^len, by the Newton iteration
/// q <- 2q - p * q^2, which doubles the number of correct coefficients per level.
/// @param p   series to invert; p.v[0] must be non-zero mod `mod` (otherwise
///            Pow returns 0 and the result is meaningless).
/// @param q   output; q.v.size() must be >= len. Only q.v[0..len) is written,
///            and whatever q held before does not affect the result.
/// @param len target precision; must be a power of two.
void Inv(Poly p, Poly &q, int len) {
    if (len == 1) {
        q.v[0] = Pow(p.v[0], mod - 2);
        return;
    }
    const int half = len >> 1;
    // Only p mod x^len influences the answer. Truncating here also keeps the
    // transform size proportional to len rather than to p's full length.
    p.v.resize(len);
    Inv(p, q, half);
    // Use exactly the `half` coefficients that are already correct. Anything q
    // holds above them (e.g. leftovers in a reused buffer) would otherwise
    // enter p * q * q, and the cyclic NTT would wrap those high-degree terms
    // onto the low coefficients and corrupt them.
    Poly cur(half);
    for (int i = 0; i < half; ++i) cur.v[i] = q.v[i];
    // deg(p * cur^2) <= 2 * len - 3, below the transform size: no wrap-around.
    p.InvMul(cur);
    for (int i = 0; i < len; ++i) {
        const int prev = i < half ? cur.v[i] : 0;
        q.v[i] = ((ll)2 * prev + mod - FFT :: a[i]) % mod;
    }
}

/// Power-series logarithm: q = integral of (p' / p) mod x^len, with q[0] = 0.
/// This equals ln(p) when p.v[0] == 1 (in general ln(p / p.v[0])).
/// Requires p.v[0] != 0, p.v.size() >= len, q.v.size() >= len, and len a
/// power of two (Inv's requirement).
void Ln(Poly p, Poly &q, int len) {
    // Scratch buffers reused across calls to avoid reallocating.
    static Poly c, a;
    // assign, not resize: c must start zeroed. With resize, coefficients left
    // over from a previous call (e.g. a second call with the same len) can
    // reach Inv's Newton step and, via NTT wrap-around, corrupt the inverse.
    c.v.assign(len, 0);
    a.v.resize(len);
    Inv(p, c, len);      // c = 1 / p mod x^len
    ICalc(p, a, len);    // a = p' mod x^len (fully overwritten)
    c = c * a;
    c.v.resize(len);     // c = p' / p mod x^len
    Calc(c, q, len);     // q = integral of c
}

/// Power-series exponential: q = exp(p) mod x^len, by the Newton iteration
/// q <- q * (1 - ln(q) + p).
/// Presumably requires p.v[0] == 0 (exp(p) is then a formal power series with
/// constant term 1) and len a power of two; q.v.size() must be >= 1 and is
/// resized to len.
void Exp(Poly p, Poly &q, int len) {
    if (len == 1) {
        q.v[0] = 1;  // exp(p) mod x starts from the constant 1
        return;
    }
    static Poly d;
    const int half = len >> 1;
    Exp(p, q, half);
    q.v.resize(len);
    d.v.resize(len);
    Ln(q, d, len);
    Inc(d.v[0], mod - 1);   // d = ln(q) - 1
    d = p - d;              // d = p + 1 - ln(q)
    d.v.resize(len);
    q = q * d;
    q.v.resize(len);
}

void Sqrt(Poly p, Poly &q, int len) {
    if (len == 1) {
        q.v[0] = sqrt(p.v[0]);
        return;
    }
    static Poly c, a;
    Sqrt(p, q, len >> 1), c.v.resize(len), Inv(q, c, len);
    a = p, a.v.resize(len), a = a * c, a.v.resize(len);
    q = (q + a) * inv2, q.v.resize(len);
}

/// Polynomial remainder a mod b, computed via the reversed-series inverse.
/// Requires the stored leading coefficient b.v[b.Length() - 1] to be non-zero,
/// because it is inverted. Returns a itself when a.Length() < b.Length().
/// Otherwise the result has exactly b.Length() - 1 coefficients. A constant
/// divisor yields the zero polynomial {0}, never an empty coefficient vector.
inline Poly operator %(const Poly &a, const Poly &b) {
    const int n = a.Length(), m = b.Length();
    if (n < m) return a;
    if (m <= 1) return Poly();          // remainder modulo a non-zero constant is 0
    const int res = n - m + 1;          // number of quotient coefficients
    Poly x = a, y = b, z;
    reverse(x.v.begin(), x.v.end());
    reverse(y.v.begin(), y.v.end());
    x.v.resize(res);
    y.v.resize(res);
    y.Adjust();                         // Inv needs a power-of-two precision
    z.v.resize(y.Length());
    Inv(y, z, y.Length());
    z.v.resize(res);
    x = x * z;
    x.v.resize(res);
    reverse(x.v.begin(), x.v.end());    // x = quotient of a / b
    y = a - x * b;
    y.v.resize(m - 1);
    return y;
}

// Subproduct tree used by Lagrange(): Build() stores in f[o] the product of
// (X - x[i]) over the points covered by node o (children 2o and 2o + 1).
Poly f[maxn];
// Scratch polynomials for Lagrange(): a holds the derivative of f[1] and
// b receives the interpolated coefficients.
Poly a, b;
// n: number of points read by Lagrange(). m: not referenced in the visible code.
int n, m;
// 1-indexed input points (x[i], y[i]) and the per-point weights
// ans[i] = y[i] / prod_{j != i}(x[i] - x[j]) filled in by Solve_val().
int x[maxn], y[maxn], ans[maxn];

void Build(int o, int l, int r) {
    if (l == r) {
        f[o].v.resize(2), f[o].v[0] = mod - x[l], f[o].v[1] = 1;
        return;
    }
    register int mid = (l + r) >> 1;
    Build(o << 1, l, mid), Build(o << 1 | 1, mid + 1, r);
    f[o] = f[o << 1] * f[o << 1 | 1];
}

/// Multipoint evaluation down the subproduct tree: for every i in [l, r],
/// stores ans[i] = y[i] / cur(x[i]).
/// cur is reduced modulo each child's product f[child] before recursing.
/// Ranges of at most kNaiveLimit points are evaluated directly with
/// Poly::Calc instead of paying for further NTT-based remainders.
/// Lagrange() calls this with cur = f[1]', so cur(x[i]) = prod_{j != i}(x[i] - x[j]).
/// If cur(x[i]) == 0 (duplicate x), Pow returns 0 and ans[i] becomes 0.
void Solve_val(Poly cur, int o, int l, int r) {
    const int kNaiveLimit = 2000;
    if (r - l + 1 <= kNaiveLimit) {
        for (int i = l; i <= r; ++i) ans[i] = 1LL * y[i] * Pow(cur.Calc(x[i]), mod - 2) % mod;
        return;
    }
    const int mid = (l + r) >> 1;
    Solve_val(cur % f[o << 1], o << 1, l, mid);
    Solve_val(cur % f[o << 1 | 1], o << 1 | 1, mid + 1, r);
}

/// Builds the interpolation numerator for points l..r:
///   cur = sum_{i=l..r} ans[i] * prod_{j in [l, r], j != i} (X - x[j]),
/// by combining the halves as left * f[right child] + right * f[left child].
/// The result has r - l + 1 coefficients; cur.v must be non-empty at a leaf.
void Solve(Poly &cur, int o, int l, int r) {
    if (l == r) {
        cur.v[0] = ans[l];
        return;
    }
    const int mid = (l + r) >> 1;
    Poly lp(mid - l + 1), rp(r - mid);
    Solve(lp, o << 1, l, mid);
    Solve(rp, o << 1 | 1, mid + 1, r);
    cur = lp * f[o << 1 | 1] + rp * f[o << 1];
}

/// Reads n followed by n points (x_i, y_i) from stdin. Prints, lowest degree
/// first, the n coefficients of the interpolating polynomial of degree < n:
///   P(X) = sum_i y_i / prod_{j != i}(x_i - x_j) * prod_{j != i}(X - x_j).
/// The denominators are obtained by evaluating f' at every x_i, where
/// f = prod_j (X - x_j) is the root of the subproduct tree.
/// Assumes distinct x_i in [0, mod). A missing or non-positive n prints an
/// empty line.
/// NOTE(review): the fixed-size global arrays bound n (roughly n <= maxn / 2);
/// this is not checked.
inline void Lagrange() {
    if (scanf("%d", &n) != 1 || n < 1) {
        // Build(1, 1, n) would recurse without end for n < 1.
        puts("");
        return;
    }
    for (int i = 1; i <= n; ++i) scanf("%d%d", &x[i], &y[i]);
    Build(1, 1, n);
    a = f[1];
    // Differentiate in place: ascending i reads a.v[i + 1] before overwriting it.
    const int len = a.Length();
    for (int i = 0; i + 1 < len; ++i) a.v[i] = (ll)a.v[i + 1] * (i + 1) % mod;
    if (len > 1) a.v.pop_back();
    else a.v[0] = 0;
    b.v.resize(n);
    Solve_val(a, 1, 1, n);
    Solve(b, 1, 1, n);
    for (int i = 0; i < n; ++i) printf("%d ", b.v[i]);
    puts("");
}

// Program entry point. Intentionally a stub: this template's driver,
// Lagrange(), is not invoked here; call it to use the program.
int main() { return 0; }

大型大常數多項式模板