1. 程式人生 > >bzoj 3992 [SDOI2015] 序列統計 —— NTT (迴圈卷積+快速冪)

bzoj 3992 [SDOI2015] 序列統計 —— NTT (迴圈卷積+快速冪)

題目:https://www.lydsy.com/JudgeOnline/problem.php?id=3992

首先,如果把方案數和乘積分別放在係數和次數上,就可以用多項式做了;

方案數放在係數上好說,但次數是相加的,如何表示乘積?

考慮乘積與加法的關係 —— 冪的相乘就是指數相加;

所以可以找出乘積的模數 m 的原根,用其次數相加代表乘積,這個次數好像被稱為“指標”;

構造出多項式,由於要取模,所以用 NTT 做;

也就是要把初始的多項式做 n 次冪,可以用快速冪,但注意累乘起來的是係數而不是點值;

指標從0開始或從1開始都可以,也就是把 0 次方作為 1 和把 m-1 次方作為 1 的區別,對應係數的時候要根據這個注意一下(程式碼中註釋裡的方案也可);

初始化一個多項式並不是把每個係數都賦值1!而只有第0項是1,這樣別的多項式乘過來還是那個多項式;

然後要特別注意讀入時去掉0!因為原根系列中沒有模出0的,所以以原根為基礎的 NTT 算的時候不能考慮0,而反正最後要求的方案中,x >= 1,一旦有0,乘積就是0了,所以0對答案沒有影響,就當沒給這個數算即可;

一下午的心血...

程式碼如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
typedef 
long long ll; int const xn=(1<<14),xm=8005,mod=1004535809; int n,m,rev[xn],g,a[xn],b[xn],lim,r[xm],cnt,pri[xm],inv; int rd() { int ret=0,f=1; char ch=getchar(); while(ch<'0'||ch>'9'){if(ch=='-')f=0; ch=getchar();} while(ch>='0'&&ch<='9')ret=(ret<<3)+(ret<<1)+ch-'0',ch=getchar();
return f?ret:-ret; } int pw(ll a,int b,int md) { ll ret=1; for(;b;b>>=1,a=(a*a)%md)if(b&1)ret=(ret*a)%md; return ret; } void div(int x) { for(int i=2;i*i<=x;i++) { if(x%i)continue; pri[++cnt]=i; while(x%i==0)x/=i; } if(x>1)pri[++cnt]=x; } void init() { lim=1; int l=0; while(lim<=m+m)lim<<=1,l++; //while(lim<=2*(m-1))lim<<=1,l++; for(int i=0;i<lim;i++) rev[i]=((rev[i>>1]>>1)|((i&1)<<(l-1))); inv=pw(lim,mod-2,mod); if(m==2){g=1; return;} div(m-1); for(g=2;;g++) { bool f=0; for(int j=1;j<=cnt;j++) if(pw(g,(m-1)/pri[j],m)==1){f=1; break;} if(!f)break; } for(int i=1,k=g;i<m;i++,k=(ll)k*g%m)r[k]=i; //for(int i=0,k=1;i<m-1;i++,k=(ll)k*g%m)r[k]=i;//k=1 } int upt(int x){while(x>=mod)x-=mod; while(x<0)x+=mod; return x;} void ntt(int *a,int tp) { for(int i=0;i<lim;i++) if(i<rev[i])swap(a[i],a[rev[i]]); for(int mid=1;mid<lim;mid<<=1) { int wn=pw(3,(mod-1)/(mid<<1),mod); if(tp==-1)wn=pw(wn,mod-2,mod);// for(int j=0,len=(mid<<1);j<lim;j+=len) { int w=1; for(int k=0;k<mid;k++,w=(ll)w*wn%mod) { int x=a[j+k],y=(ll)w*a[j+mid+k]%mod; a[j+k]=upt(x+y); a[j+mid+k]=upt(x-y); } } } } void pww() { ntt(a,1); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*a[i]%mod; ntt(a,-1); for(int i=0;i<lim;i++)a[i]=(ll)a[i]*inv%mod; for(int i=m;i<lim;i++)a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0;//%m+1 //for(int i=m-1;i<lim;i++)a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0; } void mul() { ntt(a,1); ntt(b,1);// for(int i=0;i<lim;i++)b[i]=(ll)b[i]*a[i]%mod; ntt(a,-1); ntt(b,-1);// for(int i=0;i<lim;i++) a[i]=(ll)a[i]*inv%mod,b[i]=(ll)b[i]*inv%mod; for(int i=m;i<lim;i++) a[i%m+1]=upt(a[i%m+1]+a[i]),a[i]=0, b[i%m+1]=upt(b[i%m+1]+b[i]),b[i]=0; /* for(int i=m-1;i<lim;i++) a[i%(m-1)]=upt(a[i%(m-1)]+a[i]),a[i]=0, b[i%(m-1)]=upt(b[i%(m-1)]+b[i]),b[i]=0; */ } int main() { n=rd(); m=rd(); init(); int p=rd(),num=rd(); for(int i=1,x;i<=num;i++) { x=rd(); if(x)a[r[x]]=1;//x!=0 !! } int t=n; //for(int i=0;i<lim;i++)b[i]=1; b[0]=1;//! for(;t;t>>=1,pww())if(t&1)mul(); printf("%d\n",b[r[p]]); return 0; }