1. 程式人生 > >3992: [SDOI2015]序列統計

3992: [SDOI2015]序列統計

tor stream 乘法 ons 每次 單位 dot 長度 我們

3992: [SDOI2015]序列統計

鏈接

分析:

  給定一個集和s,求多少個長度為n的序列,滿足序列中每個數都屬於s,並且所有數的乘積模m等於x。

  設$f=\sum\limits_{i=0}^{n - 1} a_i x ^ i \ \ 如果集合中存在i,a_i = 1$

  那麽答案的生成函數為f自乘n次,這裏可以快速冪。這裏"乘法"定義是:設多項式a乘多項式b等於c,$\sum\limits_{k=0}^{n - 1} c_k = \sum\limits_{i \times j = k} a_i \times b_j$ 每次“乘法”的復雜度是$m^2$,所以復雜度是$O(m^2logn)$。

  考慮優化“乘法”的部分,我們知道多項式乘法利用FFT/NTT可以做到$nlogn$的,看能否轉化為多項式乘法,即多項式乘法的定義變為$\sum\limits_{k=0}^{n - 1} c_k = \sum\limits_{i + j = k} a_i \times b_j$。

  NTT中,有引入原根的概念,在NTT中,原根的用途相當於單位根。 原根有一個性質:對於mod p下的原根g,$g^1, g^2 \dots g^{p - 1}$互不相同,$g^{p - 1} \equiv 1 \mod p$。而且$g^1, g^2 \dots g^{p - 1}$可以分別表示$1,2 \dots p - 1$。

  那麽我們對m求出單位根,集合S中出現的每個數,都可以表示為$s_i = g^{t_{s_i}}$

  此時對於原來的一個序列y,$\prod y_i = x \mod m$,就變成了$\prod g ^{t_{y_i}} = g^{t_x} \mod m$,即$\sum t_{y_i} = x \mod m - 1$

  現在我們求的就是長度為n的序列,序列中每個數都屬於集合t,並且所有數的和模(m-1)等於x 如此按照上面的做法,將乘法的定義改為多項式乘法的定義,快速冪+NTT即可復雜度$mlogmlogn$。

  註意:多項式乘法中是沒有取模的,而這裏(i+j)%(m-1),直接將數組加倍,然後NTT完後,大於等於m的加到相應的模m後的位置上即可。

代碼:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cctype>
#include<set>
#include<queue>
#include<vector>
#include<map>
using namespace std;
typedef long long LL;

inline int read() {
    int x=0,f=1;char ch=getchar();for(;!isdigit(ch);ch=getchar())if(ch==-)f=-1;
    for(;isdigit(ch);ch=getchar())x=x*10+ch-0;return x*f;
}

const int mod = 1004535809;
const int N = 20000;
int vis[N], rev[N], n = 1, m;
int f[N], g[N], a[N], b[N], inv;

int ksm(int a,int b,int p) {
    a %= p;
    int ans = 1;
    while (b) {
        if (b & 1) ans = 1ll * ans * a % p;
        a = 1ll * a * a % p;
        b >>= 1;
    }
    return ans % p;
}
int Calc(int x) {
    if (x == 2) return 1;
    for (int i = 2; ; ++i) {
        bool flag = 1;
        for (int j = 2; j * j < x; ++j) 
            if (ksm(i, (x - 1) / j, x) == 1) { flag = false; break; }
        if (flag) return i;
    }
}
void NTT(int *a,int n,int ty) {
    for (int i = 0; i < n; ++i) if (i < rev[i]) swap(a[i], a[rev[i]]);
    for (int m = 2; m <= n; m <<= 1) {
        int w1 = ksm(3, (mod - 1) / m, mod);
        if (ty == -1) w1 = ksm(w1, mod - 2, mod);
        for (int i = 0; i < n; i += m) {
            int w = 1;
            for (int k = 0; k < (m >> 1); ++k) {
                int u = a[i + k], t = 1ll * w * a[i + k + (m >> 1)] % mod;
                a[i + k] = (u + t) % mod;
                a[i + k + (m >> 1)] = (u - t + mod) % mod;
                w = 1ll * w * w1 % mod;                
            }
        }
    }
}
void mul(int *g,int *f) {
    for (int i = 0; i < n; ++i) a[i] = g[i] % mod, b[i] = f[i] % mod;
    NTT(a, n, 1);
    NTT(b, n, 1);
    for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * b[i] % mod;
    NTT(a, n, -1);
    for (int i = 0; i < n; ++i) a[i] = 1ll * a[i] * inv % mod;
    for (int i = 0; i < m - 1; ++i) g[i] = (a[i] + a[i + m - 1]) % mod;    
}
void solve(int b) {
    inv = ksm(n, mod - 2, mod);
    g[0] = 1;
    while (b) {
        if (b & 1) mul(g, f);
        b >>= 1;
        mul(f, f);
    }
}
int main() {
    int cnt = read(); m = read(); int x = read(), s = read();
    for (int i = 1; i <= s; ++i) vis[read()] = 1;
    int q = Calc(m), pos = -1, L = 0;
    for (int i = 0, j = 1; i < m - 1; ++i, j = 1ll * j * q % m) {
        if (vis[j]) f[i] = 1;
        if (j == x) pos = i;
    }
    int M = (m - 1) * 2;
    while (n < M) n <<= 1, L ++;
    for (int i = 0; i < n; ++i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << (L - 1));
    solve(cnt);
    if (pos != -1) cout << g[pos] % mod;
    else cout << 0;
    return 0;
}

3992: [SDOI2015]序列統計