1. 程式人生 > (模板)多項式乘法對任意數取模

(模板)多項式乘法對任意數取模

// Polynomial multiplication with coefficients reduced modulo MOD = 1000000007.
// The constant factor is very large (one NTT pass per prime plus a CRT merge
// for every coefficient), so use with care.
// In theory MOD can be chosen freely, as long as the product of the K chosen
// primes is greater than MOD * MOD * N.
#define MOD 1000000007
// Number of NTT primes used for the CRT reconstruction.
#define K 3

// NTT-friendly primes: 479 * 2^21 + 1, 119 * 2^23 + 1 and 25 * 2^22 + 1.
const int m[K] = {1004535809, 998244353, 104857601};
// Generator used to derive the roots of unity; 3 is a primitive root of every prime in m[].
#define G 3


/*
 * Modular exponentiation by repeated squaring.
 *
 * x: base, any int (negative values are mapped to their residue in [0, P)).
 * k: exponent, expected to be non-negative; for k <= 0 the loop does not run
 *    and the result is 1 % P.
 * P: modulus, must be positive.
 *
 * Returns x^k mod P as a value in [0, P).
 */
int qpow(int x, int k, int P) {
	// Canonicalise the base so every intermediate value stays in [0, P).
	long long base = x % P;
	if(base < 0) base += P;

	// 1 % P is 0 when P == 1, which keeps the result inside [0, P).
	long long ret = 1 % P;

	// "k > 0" guarantees termination even if a negative exponent slips in.
	while(k > 0) {
		if(k & 1) ret = ret * base % P;
		base = base * base % P;
		k >>= 1;
	}
	return (int)ret;
}

// Number-theoretic transform over one prime modulus P.
// wn[i] holds G^((P-1)/2^i) mod P; only wn[1..21] are filled by init(), so
// transform lengths are limited to powers of two up to 2^21.
struct _NTT {
	int wn[25], P;

	// Select the modulus and precompute the roots of unity wn[1..21].
	void init(int _P) {
		P = _P;
		for(int lg = 1, size = 2; lg <= 21; ++lg, size <<= 1)
			wn[lg] = qpow(G, (P - 1) / size, P);
	}

	// In-place bit-reversal reordering of y[0..len-1] (len is expected to be a power of two).
	void change(int *y, int len) {
		int rev = len / 2;
		for(int pos = 1; pos < len - 1; ++pos) {
			if(pos < rev) swap(y[pos], y[rev]);
			// Increment the bit-reversed counter: clear leading set bits, then set the next one.
			int bit = len / 2;
			for(; rev >= bit; bit /= 2) rev -= bit;
			rev += bit;
		}
	}

	// In-place transform of y[0..len-1] modulo P.
	// on == -1 selects the inverse transform (including the 1/len scaling);
	// any other value performs the forward transform.
	void NTT(int *y, int len, int on) {
		change(y, len);

		int level = 0;
		for(int span = 2; span <= len; span <<= 1) {
			const int half = span / 2;
			const int root = wn[++level];
			for(int base = 0; base < len; base += span) {
				int w = 1;
				for(int lo = base, hi = base + half; lo < base + half; ++lo, ++hi) {
					const int u = y[lo];
					const int t = 1LL * y[hi] * w % P;

					int sum = u + t;
					if(sum >= P) sum -= P;
					int diff = u - t + P;
					if(diff >= P) diff -= P;

					y[lo] = sum;
					y[hi] = diff;
					w = 1LL * w * root % P;
				}
			}
		}

		if(on != -1) return;

		// Inverse transform: mirror indices 1..len-1, then divide every entry by len.
		for(int a = 1, b = len - 1; a < len / 2; ++a, --b) swap(y[a], y[b]);
		const int inv = qpow(len, P - 2, P);
		int idx = 0;
		while(idx < len) {
			y[idx] = 1LL * y[idx] * inv % P;
			++idx;
		}
	}

	// Cyclic convolution of A and B modulo P; the result is stored in A.
	// B is overwritten with its forward transform.
	void mul(int A[], int B[], int len) {
		NTT(A, len, 1);
		NTT(B, len, 1);
		int idx = 0;
		while(idx < len) {
			A[idx] = 1LL * A[idx] * B[idx] % P;
			++idx;
		}
		NTT(A, len, -1);
	}
}ntt[K];

// tmp[i][id]: coefficient i of the product modulo m[id]; written by mul() and read by CRT().
// t1, t2: scratch copies of the two operands handed to each per-prime NTT.
// N (the buffer capacity) is defined outside this part of the file.
int tmp[N][K], t1[N], t2[N];
// r[j][i] for j < i: modular inverse of m[j] modulo m[i], filled in by init() and used by CRT().
int r[K][K];

int CRT(int a[]) {
	int x[K];
	for(int i = 0; i < K; ++i) {
		x[i] = a[i];
		for(int j = 0; j < i; ++j) {
			int t = (x[i] - x[j]) % m[i];
			if(t < 0) t += m[i];
			x[i] = 1LL * t * r[j][i] % m[i];
		}
	}
	int mul = 1, ret = x[0] % MOD;
	for(int i = 1; i < K; ++i) {
		mul = 1LL * mul * m[i-1] % MOD;
		ret += 1LL * x[i] * mul % MOD;
		if(ret >= MOD) ret -= MOD;
	}
	return ret;
}

void mul(int A[], int B[], int len) {
	for(int id = 0; id < K; ++id) {

		for(int i = 0; i < len; ++i) {
			t1[i] = A[i];
			t2[i] = B[i];
		}
		ntt[id].mul(t1, t2, len);
		for(int i = 0; i < len; ++i) 
			tmp[i][id] = t1[i];
	}
	for(int i = 0; i < len; ++i) {
		A[i] = CRT(tmp[i]);

	}
}

/*
 * One-time setup; call before mul().
 * For every prime m[i] this prepares the matching NTT instance and stores in
 * r[j][i] the inverse of each earlier prime m[j] modulo m[i] (via Fermat's
 * little theorem, since m[i] is prime).
 */
void init() {
	for(int i = 0; i < K; ++i) {
		const int p = m[i];
		ntt[i].init(p);
		for(int j = 0; j < i; ++j)
			r[j][i] = qpow(m[j], p - 2, p);
	}
}