1. 程式人生 > 實用技巧 >【洛谷P3321】序列統計

【洛谷P3321】序列統計

題目

題目連結:https://www.luogu.com.cn/problem/P3321
小C有一個集合 \(S\),裡面的元素都是小於 \(m\) 的非負整數。他用程式編寫了一個數列生成器,可以生成一個長度為 \(n\) 的數列,數列中的每個數都屬於集合 \(S\)
小C用這個生成器生成了許多這樣的數列。但是小C有一個問題需要你的幫助:給定整數 \(x\),求所有可以生成出的,且滿足數列中所有數的乘積 \(\bmod \ m\) 的值等於 \(x\) 的不同的數列的有多少個。
小C認為,兩個數列 \(A\)\(B\) 不同,當且僅當 \(\exists i \text{ s.t. } A_i \neq B_i\)

。另外,小C認為這個問題的答案可能很大,因此他只需要你幫助他求出答案對 \(1004535809\) 取模的值就可以了。
\(n\leq 10^9,m\leq 8000\)

思路

之前在 GMOJ 這道題時限開 \(5s\) 被我 \(O(m^2\log n)\) 艹過去了。
首先 \(60\)pts 的倍增 dp 就是設 \(f[i][j]\) 表示選了 \(2^i\) 個數,乘積 \(\bmod p\) 之後的結果為 \(j\) 的方案數。
轉移為

\[f[k][l]=\sum^{}_{i\times j\bmod p=l}f[k-1][i]\times f[k-1][j] \]

然後二進位制拆分即可。
如果這個乘號是加號的話,我們就可以 NTT 優化了。
考慮如何把乘號變為加號,因為 \(\log_ab+\log_ac=\log_a(bc)\)

,所以可以用對數進行轉化。
但是我們需要保證轉化後對於任意兩個 \(x,y\in [1,m)\)\(x\neq y\),都有 \(\log_a x\neq \log_a y\),由於 \(m\) 是質數,所以我們取 \(m\) 的原根即可。
接下來就和 \(60\)pts 的做法一樣了。將每一數轉化為對數之後扔進一個多項式裡,然後倍增計算即可。
時間複雜度 \(O(m\log n\log m)\)

程式碼

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=18010,MOD=1004535809;
int n,m,s,l,G,lim,a[N],rev[N];
ll f[N],g[N],h[N];

ll fpow(ll x,ll k,ll mod=(ll)MOD)
{
	ll ans=1;
	for (;k;k>>=1,x=x*x%mod)
		if (k&1) ans=ans*x%mod;
	return ans;
}

int findg(int p)
{
	vector<int> d;
	for (int i=2;i<=p-1;i++)
		if ((p-1)%i==0) d.push_back(i);
	for (int i=1;i<=p;i++)
	{
		bool flag=1;
		for (int j=0;j<d.size();j++)
			if (fpow(i,(p-1)/d[j],p)==1) { flag=0; break; }
		if (flag) return i;
	}
}

void NTT(ll *f,bool tag)
{
	for (int i=0;i<lim;i++)
		if (i<rev[i]) swap(f[i],f[rev[i]]);
	for (int k=1;k<lim;k<<=1)
	{
		ll tmp=fpow((tag?3:334845270),(MOD-1)/(k<<1));
		for (int i=0;i<lim;i+=(k<<1))
		{
			ll w=1;
			for (int j=0;j<k;j++,w=w*tmp%MOD)
			{
				ll x=f[i+j],y=w*f[i+j+k]%MOD;
				f[i+j]=(x+y)%MOD; f[i+j+k]=(x-y)%MOD;
			}
		}
	}
}

int main()
{
	scanf("%d%d%d%d",&n,&m,&s,&l);  // 十分優雅的讀入
	G=findg(m);
	for (int i=1;i<m;i++)
		a[fpow(G,i,m)]=i;
	for (int i=1,x;i<=l;i++)
	{
		scanf("%d",&x);
		if (x) f[a[x]]++;
	}
	g[0]=lim=1;
	while (lim<=2*m) lim<<=1;
	for (int i=0;i<lim;i++)
		rev[i]=(rev[i>>1]>>1)|((i&1)?(lim>>1):0);
	ll inv=fpow(lim,MOD-2);
	for (int k=0;k<=30;k++)
	{
		if (n&(1<<k))
		{
			memcpy(h,f,sizeof(f));
			NTT(g,1); NTT(h,1);
			for (int i=0;i<lim;i++) g[i]=g[i]*h[i]%MOD;
			NTT(g,0);
			for (int i=1;i<m;i++)
				g[i]=(g[i]+g[i+m-1])*inv%MOD;
			for (int i=m;i<lim;i++) g[i]=0;
		}
		NTT(f,1);
		for (int i=0;i<lim;i++) f[i]=f[i]*f[i]%MOD;
		NTT(f,0);
		for (int i=1;i<m;i++)
			f[i]=((f[i]+f[i+m-1])*inv%MOD+MOD)%MOD;
		for (int i=m;i<lim;i++) f[i]=0;
	}
	printf("%lld",(g[a[s]]%MOD+MOD)%MOD);
	return 0;
}