【洛谷P3321】序列統計
阿新 • • 發佈:2021-01-07
題目
題目連結:https://www.luogu.com.cn/problem/P3321
小C有一個集合 \(S\),裡面的元素都是小於 \(m\) 的非負整數。他用程式編寫了一個數列生成器,可以生成一個長度為 \(n\) 的數列,數列中的每個數都屬於集合 \(S\)。
小C用這個生成器生成了許多這樣的數列。但是小C有一個問題需要你的幫助:給定整數 \(x\),求所有可以生成出的,且滿足數列中所有數的乘積 \(\bmod \ m\) 的值等於 \(x\) 的不同的數列的有多少個。
小C認為,兩個數列 \(A\) 和 \(B\) 不同,當且僅當 \(\exists i \text{ s.t. } A_i \neq B_i\)
\(n\leq 10^9,m\leq 8000\)。
思路
之前在 GMOJ 這道題時限開 \(5s\) 被我 \(O(m^2\log n)\) 艹過去了。
首先 \(60\)pts 的倍增 dp 就是設 \(f[i][j]\) 表示選了 \(2^i\) 個數,乘積 \(\bmod p\) 之後的結果為 \(j\) 的方案數。
轉移為
然後二進位制拆分即可。
如果這個乘號是加號的話,我們就可以 NTT 優化了。
考慮如何把乘號變為加號,因為 \(\log_ab+\log_ac=\log_a(bc)\)
但是我們需要保證轉化後對於任意兩個 \(x,y\in [1,m)\) 且 \(x\neq y\),都有 \(\log_a x\neq \log_a y\),由於 \(m\) 是質數,所以我們取 \(m\) 的原根即可。
接下來就和 \(60\)pts 的做法一樣了。將每一數轉化為對數之後扔進一個多項式裡,然後倍增計算即可。
時間複雜度 \(O(m\log n\log m)\)。
程式碼
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int N=18010,MOD=1004535809; int n,m,s,l,G,lim,a[N],rev[N]; ll f[N],g[N],h[N]; ll fpow(ll x,ll k,ll mod=(ll)MOD) { ll ans=1; for (;k;k>>=1,x=x*x%mod) if (k&1) ans=ans*x%mod; return ans; } int findg(int p) { vector<int> d; for (int i=2;i<=p-1;i++) if ((p-1)%i==0) d.push_back(i); for (int i=1;i<=p;i++) { bool flag=1; for (int j=0;j<d.size();j++) if (fpow(i,(p-1)/d[j],p)==1) { flag=0; break; } if (flag) return i; } } void NTT(ll *f,bool tag) { for (int i=0;i<lim;i++) if (i<rev[i]) swap(f[i],f[rev[i]]); for (int k=1;k<lim;k<<=1) { ll tmp=fpow((tag?3:334845270),(MOD-1)/(k<<1)); for (int i=0;i<lim;i+=(k<<1)) { ll w=1; for (int j=0;j<k;j++,w=w*tmp%MOD) { ll x=f[i+j],y=w*f[i+j+k]%MOD; f[i+j]=(x+y)%MOD; f[i+j+k]=(x-y)%MOD; } } } } int main() { scanf("%d%d%d%d",&n,&m,&s,&l); // 十分優雅的讀入 G=findg(m); for (int i=1;i<m;i++) a[fpow(G,i,m)]=i; for (int i=1,x;i<=l;i++) { scanf("%d",&x); if (x) f[a[x]]++; } g[0]=lim=1; while (lim<=2*m) lim<<=1; for (int i=0;i<lim;i++) rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0); ll inv=fpow(lim,MOD-2); for (int k=0;k<=30;k++) { if (n&(1<<k)) { memcpy(h,f,sizeof(f)); NTT(g,1); NTT(h,1); for (int i=0;i<lim;i++) g[i]=g[i]*h[i]%MOD; NTT(g,0); for (int i=1;i<m;i++) g[i]=(g[i]+g[i+m-1])*inv%MOD; for (int i=m;i<lim;i++) g[i]=0; } NTT(f,1); for (int i=0;i<lim;i++) f[i]=f[i]*f[i]%MOD; NTT(f,0); for (int i=1;i<m;i++) f[i]=((f[i]+f[i+m-1])*inv%MOD+MOD)%MOD; for (int i=m;i<lim;i++) f[i]=0; } printf("%lld",(g[a[s]]%MOD+MOD)%MOD); return 0; }