1. 程式人生 > >BZOJ 3992 [SDOI2015]序列統計 NTT

BZOJ 3992 [SDOI2015]序列統計 NTT

題意:
存在一個集合S,求長度為N,每一個元素都是S中的元素(可重複),並且該序列所有數的乘積mod M = x 的序列個數。
M是質數,且集合中的所有元素的範圍都在[0,M-1]內。
並且x!=0
解析:
因為有M是質數這個特殊條件,所以我們可以求出來M的原根G,之後因為G的0~(phi(M)-1)可以完美替代0~M-1中的數,於是我們可以考慮把S中所有的數用G的幾次冪來代替。
至於為什麼這樣考慮。
因為這樣就把我們所需要的乘法轉化成了冪的加法。
搞出集合S的生成函式。
由於每個數可以選取多次,所以接下來的問題就是S的生成函式的n次冪的對應的第x次冪項。
我們發現過程其實就是多項式的乘積過程,並且題目要求答案mod 一個原根為3的大質數,所以我們可以考慮用NTT來優化這一過程。
需要注意的是,在多項式乘積的時候,我們每一次要把大於m的係數加到其mod m後的那一項上,也就是說,不要直接消除,而是在乘積的時候把越界的部分轉到mod m下。
總複雜度O(lognmlogm)
程式碼:

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define mod 1004535809
#define G 3
#define N 262145
using namespace std;
typedef long long ll;
int n,m,x,s,root;
ll prime[20010];
int pos[17010];
ll a[N],b[N];
int rev[N];
int num[17010];
int tot;
ll mm;
ll quick_my(ll x
,ll y,ll MOD) { ll ret=1; while(y) { if(y&1)ret=ret*x%MOD; x=x*x%MOD; y>>=1; } return ret; } void get_factor(ll x) { tot=0; for(ll i=2;i*i<=x;i++) { if(x%i==0) { prime[++tot]=i; while(x%i==0)x/=i; } } if
(x!=1)prime[++tot]=x; } bool check(ll x,ll MOD,ll PHI) { for(int i=1;i<=tot;i++) { if(quick_my(x,PHI/prime[i],MOD)==1)return 0; } return 1; } int find_primitive_root(ll x) { ll tmp=x-1; get_factor(tmp); for(int i=2;;i++) { if(check(i,x,tmp))return i; } } void NTT(ll *a,int f) { for(int i=0;i<m;i++)if(i<rev[i])swap(a[i],a[rev[i]]); for(int h=2;h<=m;h<<=1) { ll wn=quick_my(G,(mod-1)/h,mod); for(int i=0;i<m;i+=h) { ll w=1; for(int j=0;j<(h>>1);j++,w=w*wn%mod) { ll t=w*a[i+j+(h>>1)]%mod; a[i+j+(h>>1)]=((a[i+j]-t)%mod+mod)%mod; a[i+j]=(a[i+j]+t)%mod; } } } if(f==-1) { for(int i=1;i<(m>>1);i++)swap(a[i],a[m-i]); ll inv=quick_my(m,mod-2,mod); for(int i=0;i<m;i++)a[i]=a[i]*inv%mod; } } ll ret[N]; void get_my(int y) { ret[0]=1; while(y) { NTT(b,1); if(y&1) { NTT(ret,1); for(int i=0;i<m;i++)ret[i]=ret[i]*b[i]%mod; NTT(ret,-1); for(int i=m-1;i>=mm-1;i--)ret[i-mm+1]=(ret[i-mm+1]+ret[i])%mod,ret[i]=0; } for(int i=0;i<m;i++)b[i]=b[i]*b[i]%mod; NTT(b,-1); for(int i=m-1;i>=mm-1;i--)b[i-mm+1]=(b[i-mm+1]+b[i])%mod,b[i]=0; y>>=1; } } int main() { scanf("%d%d%d%d",&n,&m,&x,&s); for(int i=1;i<=s;i++)scanf("%d",&num[i]); root=find_primitive_root(m); ll tmp=1; for(int i=0;i<m-1;i++) pos[tmp]=i,tmp=tmp*root%m; int l=m*2,L=0; mm=m; for(m=1;m<=l;m<<=1)L++; for(int i=0;i<m;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); for(int i=1;i<=s;i++) if(num[i]!=0)b[pos[num[i]]]++; get_my(n); printf("%lld\n",ret[pos[x]]); }