BZOJ 3992 [SDOI2015]序列統計 NTT
阿新 • • 發佈:2018-12-24
題意:
存在一個集合S,求長度為N,每一個元素都是S中的元素(可重複),並且該序列所有數的乘積mod M = x 的序列個數。
M是質數,且集合中的所有元素的範圍都在[0,M-1]內。
並且x!=0
解析:
因為有M是質數這個特殊條件,所以我們可以求出來M的原根G,之後因為G的0~(phi(M)-1)可以完美替代0~M-1中的數,於是我們可以考慮把S中所有的數用G的幾次冪來代替。
至於為什麼這樣考慮。
因為這樣就把我們所需要的乘法轉化成了冪的加法。
搞出集合S的生成函式。
由於每個數可以選取多次,所以接下來的問題就是S的生成函式的n次冪的對應的第x次冪項。
我們發現過程其實就是多項式的乘積過程,並且題目要求答案mod 一個原根為3的大質數,所以我們可以考慮用NTT來優化這一過程。
需要注意的是,在多項式乘積的時候,我們每一次要把大於m的係數加到其mod m後的那一項上,也就是說,不要直接消除,而是在乘積的時候把越界的部分轉到mod m下。
總複雜度O(lognmlogm)
程式碼:
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define mod 1004535809
#define G 3
#define N 262145
using namespace std;
typedef long long ll;
int n,m,x,s,root;
ll prime[20010];
int pos[17010];
ll a[N],b[N];
int rev[N];
int num[17010];
int tot;
ll mm;
ll quick_my(ll x ,ll y,ll MOD)
{
ll ret=1;
while(y)
{
if(y&1)ret=ret*x%MOD;
x=x*x%MOD;
y>>=1;
}
return ret;
}
void get_factor(ll x)
{
tot=0;
for(ll i=2;i*i<=x;i++)
{
if(x%i==0)
{
prime[++tot]=i;
while(x%i==0)x/=i;
}
}
if (x!=1)prime[++tot]=x;
}
bool check(ll x,ll MOD,ll PHI)
{
for(int i=1;i<=tot;i++)
{
if(quick_my(x,PHI/prime[i],MOD)==1)return 0;
}
return 1;
}
int find_primitive_root(ll x)
{
ll tmp=x-1;
get_factor(tmp);
for(int i=2;;i++)
{
if(check(i,x,tmp))return i;
}
}
void NTT(ll *a,int f)
{
for(int i=0;i<m;i++)if(i<rev[i])swap(a[i],a[rev[i]]);
for(int h=2;h<=m;h<<=1)
{
ll wn=quick_my(G,(mod-1)/h,mod);
for(int i=0;i<m;i+=h)
{
ll w=1;
for(int j=0;j<(h>>1);j++,w=w*wn%mod)
{
ll t=w*a[i+j+(h>>1)]%mod;
a[i+j+(h>>1)]=((a[i+j]-t)%mod+mod)%mod;
a[i+j]=(a[i+j]+t)%mod;
}
}
}
if(f==-1)
{
for(int i=1;i<(m>>1);i++)swap(a[i],a[m-i]);
ll inv=quick_my(m,mod-2,mod);
for(int i=0;i<m;i++)a[i]=a[i]*inv%mod;
}
}
ll ret[N];
void get_my(int y)
{
ret[0]=1;
while(y)
{
NTT(b,1);
if(y&1)
{
NTT(ret,1);
for(int i=0;i<m;i++)ret[i]=ret[i]*b[i]%mod;
NTT(ret,-1);
for(int i=m-1;i>=mm-1;i--)ret[i-mm+1]=(ret[i-mm+1]+ret[i])%mod,ret[i]=0;
}
for(int i=0;i<m;i++)b[i]=b[i]*b[i]%mod;
NTT(b,-1);
for(int i=m-1;i>=mm-1;i--)b[i-mm+1]=(b[i-mm+1]+b[i])%mod,b[i]=0;
y>>=1;
}
}
int main()
{
scanf("%d%d%d%d",&n,&m,&x,&s);
for(int i=1;i<=s;i++)scanf("%d",&num[i]);
root=find_primitive_root(m);
ll tmp=1;
for(int i=0;i<m-1;i++)
pos[tmp]=i,tmp=tmp*root%m;
int l=m*2,L=0;
mm=m;
for(m=1;m<=l;m<<=1)L++;
for(int i=0;i<m;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1));
for(int i=1;i<=s;i++)
if(num[i]!=0)b[pos[num[i]]]++;
get_my(n);
printf("%lld\n",ret[pos[x]]);
}