1. 程式人生 > 實用技巧 >題解 P3321 【[SDOI2015]序列統計】

題解 P3321 【[SDOI2015]序列統計】

題目連結

Solution [SDOI2015]序列統計

題目大意:給定一個集合 \(S\),求生成一個長為 \(n\) ,每個元素屬於 \(S\),所有元素積 \(mod\;m=x\) 的數列的方案數。\(m\) 為質數。

原根,NTT


分析:首先我們想辦法把乘法變為加法,這樣我們就可以用卷積來計算了。

特判掉 \(S\) 中有 \(0\) 的情況。

由於 \(m\) 為質數,所以它一定有原根,因而乘法群可以變為加法群,乘法在 \(mod\;m\) 意義下運算,加法在 \(mod\;\varphi(m)=m-1\) 意義下運算。

問題轉換為:給定一個多項式,求它的 \(n\) 次冪中,所有指數 \(mod\;(m-1)\) 同餘於 \(ln(x)\) 的項的係數和,其中 \(ln(x)\) 表示 \(x\) 以原根 \(g\) 為底的離散對數。

對快速冪進行些許修改即可。兩個 \(m-2\) 次多項式相乘後,將第 \(m-1\) 項以及之後的係數累加到前面第 \(0\) 至 \(m-2\) 項上(第 \(i+m-1\) 項累加到第 \(i\) 項),相當於對指數取 \(mod\;(m-1)\)。

#include <cstdio>
#include <cctype>
#include <cstring>
#include <vector>
// Debug print helper: forwards its arguments to fprintf on stderr.
#define debug(...) fprintf(stderr,__VA_ARGS__)
using namespace std;
// Shorthand alias for long long.
typedef long long ll;
// Shared compile-time constants.
constexpr int maxn = 100000;      // capacity of the global int tables (tr, ln, exp)
constexpr int mod = 1004535809;   // default modulus of the helpers below
constexpr int G = 3;              // root used by poly::ntt for the forward transform
constexpr int invG = 334845270;   // inverse of G modulo mod (G * invG == mod + 1)
constexpr int inf = 0x7fffffff;

// (a + b) modulo `mod`; the sum is formed in int before the reduction.
constexpr int add(const int a,const int b,const int mod = ::mod){
	const int sum = a + b;
	return sum % mod;
}

// (a - b) modulo `mod`; `mod` is added once before the reduction.
constexpr int sub(const int a,const int b,const int mod = ::mod){
	const int diff = a - b + mod;
	return diff % mod;
}

// (a * b) modulo `mod`; the product is formed in long long.
constexpr int mul(const int a,const int b,const int mod = ::mod){
	const long long prod = static_cast<long long>(a) * b;
	return static_cast<int>(prod % mod);
}

// base^b modulo `mod` by binary exponentiation. //-std=c++17
constexpr int qpow(int base,int b,const int mod = ::mod){
	int acc = 1;
	for(;b != 0;b >>= 1){
		if((b & 1) != 0)acc = mul(acc,base,mod);
		base = mul(base,base,mod);
	}
	return acc;
}

// x^(mod-2) modulo `mod`: the modular inverse of x when `mod` is prime (Fermat). //-std=c++17
constexpr int inv(const int x,const int mod = ::mod){
	return qpow(x,mod - 2,mod);
}

// a multiplied by the modular inverse of b, modulo `mod`.
constexpr int calc(const int a,const int b,const int mod = ::mod){
	const int b_inv = inv(b,mod);
	return mul(a,b_inv,mod);
}
// Buffered reader/writer over stdin/stdout.
// When ONLINE_JUDGE is defined, input and output go through the internal
// 1 MiB buffers; otherwise every character is forwarded straight to the C
// stdio functions. A single global instance `io` is declared after the struct.
struct IO{//-std=c++11,with cstdio and cctype
	private:
		static constexpr int ibufsiz = 1 << 20;
		char ibuf[ibufsiz + 1],*inow = ibuf,*ied = ibuf;
		static constexpr int obufsiz = 1 << 20;
		char obuf[obufsiz + 1],*onow = obuf;
		const char *oed = obuf + obufsiz;
		// True when the most recent getchar() call hit end of input.
		bool ieof = false;
		// isdigit has undefined behaviour for negative char values, so the
		// argument is converted to unsigned char first.
		static bool digit(const char c){return std::isdigit(static_cast<unsigned char>(c)) != 0;}
	public:
		IO() = default;
		// The cursor members point into this object's own buffers, so a
		// member-wise copy would alias another object's storage.
		IO(const IO&) = delete;
		IO &operator = (const IO&) = delete;

		// Reports whether the most recent getchar() call hit end of input.
		bool eof()const{return ieof;}

		// Returns the next input character. At end of input it returns
		// (char)EOF in direct mode and '\0' in buffered mode, and eof()
		// becomes true. The buffered cursor is never moved past the data.
		char getchar(){
			#ifndef ONLINE_JUDGE
				const int c = std::getchar();
				ieof = (c == EOF);
				return static_cast<char>(c);
			#else
				if(inow == ied){
					ied = ibuf + std::fread(ibuf,sizeof(char),ibufsiz,stdin);
					*ied = '\0';
					inow = ibuf;
				}
				ieof = (inow == ied);
				if(ieof)return '\0';
				return *inow++;
			#endif
		}

		// Reads one decimal integer into x. Non-digit characters before the
		// number are skipped; if any of them is '-', the result is negated.
		// If the input ends before a digit is seen, x is left at 0 instead
		// of looping forever.
		template<typename T>
		void read(T &x){
			bool neg = false;
			x = 0;
			char c = getchar();
			while(!ieof && !digit(c)){
				if(c == '-')neg = true;
				c = getchar();
			}
			while(!ieof && digit(c)){
				x = x * 10 + (c - '0');
				c = getchar();
			}
			if(neg)x = -x;
		}
		// Reads several integers in argument order.
		template <typename T,typename ...Y>
		void read(T &x,Y&... X){read(x);read(X...);}
		// Convenience wrappers returning the value read.
		int readi(){int res = 0;read(res);return res;}
		long long readll(){long long res = 0;read(res);return res;}

		// Writes the pending output buffer to stdout and flushes stdout.
		void flush(){
			std::fwrite(obuf,sizeof(char),onow - obuf,stdout);
			std::fflush(stdout);
			onow = obuf;
		}
		// Emits one character (buffered only when ONLINE_JUDGE is defined).
		void putchar(char c){
			#ifndef ONLINE_JUDGE
				std::putchar(c);
			#else
				*onow++ = c;
				if(onow == oed){
					std::fwrite(obuf,sizeof(char),obufsiz,stdout);
					onow = obuf;
				}
			#endif
		}
		// Writes x in decimal, followed by `split` unless it is '\0'.
		// Digits are taken from the (possibly negative) remainders, so the
		// value is never negated and the most negative value of T is safe.
		template <typename T>
		void write(T x,char split = '\0'){
			unsigned char buf[64];
			int p = 0;
			const bool neg = x < 0;
			do{
				const int d = static_cast<int>(x % 10);
				buf[p++] = static_cast<unsigned char>(neg ? -d : d);
				x /= 10;
			}while(x != 0);
			if(neg)putchar('-');
			for(int i = p - 1;i >= 0;i--)putchar(static_cast<char>(buf[i] + '0'));
			if(split != '\0')putchar(split);
		}
		// Writes a newline.
		void lf(){putchar('\n');}
		// Writes whatever is still pending in the output buffer.
		~IO(){
			std::fwrite(obuf,sizeof(char),onow - obuf,stdout);
		}
}io;

// tr[i] is the bit-reversed index of i for a transform of length `len`;
// both are (re)computed by gettr.
int tr[maxn],len;
// Sets `len` to the smallest power of two that is >= x (at least 1) and
// fills tr[1..len-1] with the bit-reversal permutation for that length.
// tr[0] is never written and keeps its zero initial value.
void gettr(const int x){
	len = 1;
	while(len < x)len <<= 1;
	const int top_bit = len >> 1;
	for(int i = 1;i < len;i++){
		// Reverse of i = reverse of (i / 2) shifted down one bit, with the
		// lowest bit of i moved to the top position.
		int rev = tr[i >> 1] >> 1;
		if(i & 1)rev |= top_bit;
		tr[i] = rev;
	}
}
struct poly : std::vector<int>{
	using std::vector<int>::vector;
	#define f (*this)
	void ntt(const int flg = 1){
		const int n = size();
		for(int i = 0;i < n;i++)
			if(i < tr[i])std::swap(f[i],f[tr[i]]);
		for(int p = 2;p <= n;p <<= 1){
			const int len = p >> 1;
			const int unit = qpow(flg == 1 ? G : invG,(mod - 1) / p);
			for(int k = 0;k < n;k += p){
				int now = 1;
				for(int l = k;l < k + len;l++){
					const int tt = mul(f[l + len],now);
					f[l + len] = sub(f[l],tt);
					f[l] = add(f[l],tt);
					now = mul(now,unit);
				}
			}
		}
		if(flg == -1){
			const int inv = ::inv(n);
			for(int i = 0;i < n;i++)f[i] = mul(f[i],inv);
		}
	}
	poly operator * (const poly &g)const{
		poly res(size());
		for(unsigned int i = 0;i < size();i++)res[i] = mul(f[i],g[i]);
		return res;
	}
	poly operator + (const poly &g)const{
		poly res(size());
		for(unsigned int i = 0;i < size();i++)res[i] = add(f[i],g[i]);
		return res;
	}
	#undef f
};
// Primitive-root search for small primes.
namespace getg{
	constexpr int maxn = 1e4;
	// fac[i] is the smallest prime factor of i for 2 <= i < maxn, once
	// sieve() has run.
	int fac[maxn];
	// Linear sieve filling fac. All working state is local, so calling it
	// again simply recomputes the same table.
	void sieve(){
		bool composite[maxn] = {};
		std::vector<int> primes;
		for(int i = 2;i < maxn;i++){
			if(!composite[i]){
				primes.push_back(i);
				fac[i] = i;
			}
			for(const int q : primes){
				if(1ll * q * i >= maxn)break;
				composite[q * i] = true;
				fac[q * i] = q;
				// q is the smallest prime factor of i: larger primes would
				// mark numbers whose smallest factor is q, not themselves.
				if(i % q == 0)break;
			}
		}
	}
	// Returns the smallest primitive root of the prime p, or -1 when p is
	// outside the supported range (p < 2 or p - 1 >= maxn) or no candidate
	// below p passes the test. sieve() must have been called first.
	int solve(const int p){
		// fac is indexed with values up to p - 1.
		if(p < 2 || p - 1 >= maxn)return -1;
		// p is captured by value and the lambda is not static: a static
		// lambda capturing the parameter by reference would keep referring
		// to the first call's (dead) parameter on every later call.
		const auto is_root = [p](const int g){
			int rest = p - 1;
			while(rest != 1){
				// g is a primitive root iff g^((p-1)/q) != 1 for every
				// prime q dividing p - 1.
				if(qpow(g,(p - 1) / fac[rest],p) == 1)return false;
				rest /= fac[rest];
			}
			return true;
		};
		// Start at 1 so that p == 2 yields 1; for p >= 3 the candidate 1 is
		// always rejected, so the first accepted value is unchanged.
		for(int i = 1;i < p;i++)
			if(is_root(i))return i;
		return -1;
	}
}
// trans[k] counts the nonzero input elements whose discrete logarithm
// (ln[...]) equals k; filled in main.
poly trans;
// n, m, x, siz: the four header values read in main.
// g: the value returned by getg::solve(m).
// exp[i]: g^i modulo m; ln[v]: the index i with exp[i] == v (both built in main).
int n,m,x,siz,g,ln[maxn],exp[maxn];
poly qpow(poly base,int b){
	gettr((m - 1) << 1);
	poly res(m - 1);res[0] = 1;
	while(b){
		if(b & 1){
			res.resize(len),base.resize(len);
			res.ntt(),base.ntt();
			res = res * base;
			res.ntt(-1),base.ntt(-1);
			for(int i = 0;i < m - 2;i++)res[i] = add(res[i],res[i + m - 1]);
			res.resize(m - 1),base.resize(m - 1);
		}
		base.resize(len);
		base.ntt();
		base = base * base;
		base.ntt(-1);
		for(int i = 0;i < m - 2;i++)base[i] = add(base[i],base[i + m - 1]);
		base.resize(m - 1);
		b >>= 1;
	}
	return res;
}
int main(){
#ifndef ONLINE_JUDGE
	freopen("fafa.in","r",stdin);
#endif
	getg::sieve();
	io.read(n,m,x,siz);
	g = getg::solve(m);
	exp[0] = 1;
	for(int i = 1;i <= m - 1;i++)exp[i] = mul(exp[i - 1],g,m);
	for(int i = 0;i < m - 1;i++)ln[exp[i]] = i;
	trans.resize(m - 1);
	for(int x,i = 1;i <= siz;i++){
		io.read(x);
		if(x)trans[ln[x]]++;
		else if(!::x){
			io.write(sub(qpow(siz,n),qpow(siz - 1,n)),'\n');
			return 0;
		}
	}
	poly &&ans = qpow(trans,n);
	io.write(ans[ln[x]],'\n');
	return 0;
}