1. 程式人生 > >Elimination Round G. PolandBall and Many Other Balls 倍增+NTT+DP

Elimination Round G. PolandBall and Many Other Balls 倍增+NTT+DP

Description 把n個球分成m組,每一組不超過2個,並且不能為空,對於所有小於等於k的分組輸出有多少種不同的方案。

Sample Input 3 3

Sample Output 5 5 1

對於暴力的DP,設f[i][j]為前i個球分成j的方案數。 那麼f[i][j]=f[i1][j]+f[i1][j1]+f[i2][j1]f[i][j]=f[i-1][j]+f[i-1][j-1]+f[i-2][j-1] 這個東西我們考慮把他變成一個多項式, f(n)就表示{f[n][0],f[n][1],…,f[n][k]} 倍增去維護他。 對於倍增的過程,我們是這樣考慮的。 從高位往低位找二進位制位, 每次將兩個相同的n合起來。 對於某個位置如果有一,那你就暴力搞。 那麼對於兩個多項式考慮合併。 f

(2n)(k)=f(n)2(k)+f(n1)2(k1)f(2n)(k)=f(n)^2(k) + f(n-1)^2(k-1) 這表示不考慮中間兩邊合起來。 中間放2個,兩邊合起來。 用NTT加速這個東西。 於是你同時要維護兩個東西f(n),f(n-1) 於是再考慮f(2n-1)的轉移。。。 f(2n1)(k)=2f(n)f(n1)(k)f(n1)2(k)f(n1)2(k1)f(2n-1)(k)=2f(n)f(n-1)(k)-f(n-1)^2(k)-f(n-1)^2(k-1) 就相當於f(n)放一邊,f(n-1)放一邊,然後再反過來。 但你這樣是有重複的。 重複的部分其實就是n那個位置不放兩邊放n-1 n那個位置放兩邊放n-1。

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;
typedef long long LL;
const LL mod = 998244353;
int _min(int x, int y) {return x < y ? x : y;}
int _max(int x, int y) {return x > y ? x : y;}
int read() {
	int s = 0, f = 1; char ch = getchar();
	while(ch < '0' || ch > '9') {if(ch == '-') f = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9') s = s * 10 + ch - '0', ch = getchar();
	return s * f;
}

LL h1[150000], h2[150000], A[150000], B[150000], C[150000], D[150000], G[150000];
int k, R[150000];

LL pow_mod(LL a, LL k) {
	LL ans = 1;
	while(k) {
		if(k & 1) (ans *= a) %= mod;
		(a *= a) %= mod; k /= 2;
	} return ans;
}

void NTT(LL y[], int len, int on) {
	for(int i = 0; i < len; i++) if(i < R[i]) swap(y[i], y[R[i]]);
	for(int i = 1; i < len; i *= 2) {
		LL wn = pow_mod(3, (LL)(mod - 1) / (i * 2)); if(on == -1) wn = pow_mod(wn, mod - 2);
		for(int j = 0; j < len; j += i * 2) {
			LL w = 1;
			for(int k = 0; k < i; k++) {
				LL u = y[j + k], v = y[j + k + i] * w % mod;
				y[j + k] = (u + v) % mod, y[j + k + i] = (u - v + mod) % mod;
				w = w * wn % mod;
			}
		}
	} if(on == -1) {
		LL tmp = pow_mod(len, mod - 2);
		for(int i = 0; i < len; i++) y[i] = y[i] * tmp % mod;
	}
}

void solve(int len) {
	memset(C, 0, sizeof(C)), memset(D, 0, sizeof(D));
	memcpy(A, h1, sizeof(A)), memcpy(B, h2, sizeof(B));
	NTT(A, len, 1), NTT(B, len, 1);
	for(int i = 0; i < len; i++) G[i] = A[i] * A[i] % mod;
	NTT(G, len, -1);
	for(int i = 0; i <= k; i++) C[i] += G[i];
	for(int i = 0; i < len; i++) G[i] = B[i] * B[i] % mod;
	NTT(G, len, -1);
	for(int i = 1; i <= k; i++) (C[i] += G[i - 1]) %= mod;
	for(int i = 0; i <= k; i++) (D[i] = (D[i] - G[i] + mod) % mod) %= mod;
	for(int i = 1; i <= k; i++) (D[i] = (D[i] - G[i - 1] + mod) % mod) %= mod;
	for(int i = 0; i < len; i++) G[i] = A[i] * B[i] % mod;
	NTT(G, len, -1);
	for(int i = 0; i <= k; i++) (D[i] += G[i] * 2LL % mod) %= mod;
	memcpy(h1, C, sizeof(h1)), memcpy(h2, D, sizeof(h2));
}

void vio() {
	for(int i = 0; i <= k; ++i) C[i] = h2[i];
	for(int i = 0; i <= k; ++i) h2[i] = h1[i];
	for(int i = 1; i <= k; ++i) {
		h1[i] = (h2[i] + h2[i - 1]) % mod;
		h1[i] = (h1[i] + C[i - 1]) % mod;
	} h1[0] = 1;
}

int main() {
	int n = read(); k = read();
	h1[0] = 1;
	int len;
	for(len = 1; len <= 2 * k + 1; len *= 2);
	for(int i = 0; i < len; i++) R[i] = (R[i >> 1] >> 1) | ((i & 1) * (len >> 1));
	for(int i = 30; i >= 0; i--) {
		solve(len);
		if(n >= (1 << i)) vio(), n -= (1 << i);
	} for(int i = 1; i <= k; i++) printf("%lld ", h1[i]);
	return 0;
}