luogu P4705 玩遊戲
https://www.luogu.com.cn/problem/P4705
又碰到一個自己不會的套路,麻了
首先把要計算的式子寫出來（對每個 \(k=1,\dots,t\)；題目要的是期望，最後還要再乘上 \(\frac{1}{nm}\)，下面先省略這個係數）：
\[\sum_{i=1}^n\sum_{j=1}^m(a_i+b_j)^k\]
首先二項式展開：
\(\large \sum\limits_{i=1}^n\sum\limits_{j=1}^m\sum\limits_{t=0}^k\binom{k}{t}a_i^tb_j^{k-t}\)
交換一下求和順序,把組合數拆開
\(\large k!\sum\limits_{t=0}^k\sum\limits_{i=1}^n\frac{a_i^t}{t!}\sum\limits_{j=1}^m\frac{b_j^{k-t}}{(k-t)!}\)
考慮生成函式
設 \(\large A(x)=\sum\limits_{t\ge 0} \frac{x^t}{t!} \sum\limits_{i=1}^n a_i^t\)
\(\large B(x)=\sum\limits_{t\ge 0} \frac{x^t}{t!} \sum\limits_{j=1}^m b_j^t\)
設答案（未乘 \(\frac{1}{nm}\)）的 EGF 為 \(F(x)\)
容易發現 \(F(x)=A(x)B(x)\)，第 \(k\) 個答案就是 \(\frac{k!}{nm}[x^k]F(x)\)
然後我們來考慮\(A(x)\)怎麼求?
主要是後面那坨\(\sum\limits_{i=1}^n a_i^t\)沒辦法快速計算
先不管 \(\frac{1}{t!}\)（最後逐項除掉即可），只看普通生成函式，並交換一波求和順序
\(\large \sum\limits_{i=1}^n\sum\limits_{t\ge 0}a_i^tx^t\)
很容易得到這個東西為
\(\large \sum\limits_{i=1}^n\frac{1}{1-a_ix}\)
然而我推到這一步就不會了 /kk
這時候我們需要一個經典套路
注意到 \((\ln(1-a_ix))'=\frac{-a_i}{1-a_ix}\)
我們上面那條式子顯然可以變為
\(\large \sum\limits_{i=1}^n\left(1-\frac{-a_ix}{1-a_ix}\right)\)
\(= \large n-x\sum\limits_{i=1}^n\frac{-a_i}{1-a_ix}\)
\(= \large n-x\sum\limits_{i=1}^n(\ln(1-a_ix))'\)
因為導數的和=和的導數
所以
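\(= \large n-x \left(\ln\prod\limits_{i=1}^n(1-a_ix)\right)'\)
裡面那個 \(\prod_{i=1}^n(1-a_ix)\) 顯然可以分治 NTT 求出
然後做多項式 ln 再求導，得到普通生成函式後把第 \(t\) 項除以 \(t!\) 就是 \(A(x)\)，\(B(x)\) 同理，最後把兩者相乘即可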
code:
#include<bits/stdc++.h>
#define N 800050
#define sz(x) ((int)x.size())
#define poly vector<int>
#define mod 998244353
using namespace std;
// Modular addition: (x + y) mod p, for x and y already reduced into [0, mod).
int add(int x, int y) {
    const int s = x + y;
    return s >= mod ? s - mod : s;
}
// Modular subtraction: (x - y) mod p, for x and y already reduced into [0, mod).
int sub(int x, int y) {
    const int d = x - y;
    return d < 0 ? d + mod : d;
}
// Modular multiplication; the product is formed in 64 bits so it cannot overflow.
int mul(int x, int y) {
    const long long prod = static_cast<long long>(x) * y;
    return static_cast<int>(prod % mod);
}
// Binary exponentiation: returns x^y mod p (y is expected to be non-negative).
int qpow(int x, int y) {
    int result = 1;
    int base = x;
    while (y != 0) {
        if (y & 1) result = mul(result, base);
        y >>= 1;
        base = mul(base, base);
    }
    return result;
}
// Generator used by ntt() to derive roots of unity (3 is a primitive root of 998244353),
// and its modular inverse for the inverse transform.
const int G = 3;
const int Ginv = qpow(G, mod - 2);
// Bit-reversal permutation table; ntt() rebuilds its first n entries on every call.
int rev[N << 1];
// In-place iterative number-theoretic transform of a[0..n-1] modulo `mod`.
// Assumes n is a power of two (operator* only passes powers of two) and that n
// divides mod - 1, so (mod - 1) / len below is exact.
// o == 1 uses the root G (forward transform); any other o uses Ginv, and
// o == -1 additionally multiplies every entry by 1/n (inverse transform).
void ntt(int *a, int n, int o) {
// Rebuild the bit-reversal permutation for this length in the shared rev[] table,
// then reorder the input so the butterflies below can run in place.
for(int i = 0; i < n; i ++) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) * (n >> 1));
for(int i = 0; i < n; i ++) if(rev[i] > i) swap(a[i], a[rev[i]]);
// Radix-2 butterflies: merge adjacent length-len/2 transforms into length-len ones.
for(int len = 2; len <= n; len <<= 1) {
// w0 = G^((mod-1)/len) (or Ginv^...), the twiddle step for this level.
int w0 = qpow((o == 1)? G : Ginv, (mod - 1) / len);
for(int j = 0; j < n; j += len) {
int wn = 1;
for(int k = j; k < j + (len >> 1); k ++, wn = mul(wn, w0)) {
int X = a[k], Y = mul(a[k + (len >> 1)], wn);
a[k] = add(X, Y), a[k + (len >> 1)] = sub(X, Y);
}
}
}
// 1/n mod p; computed for every call but only applied when o == -1.
int ninv = qpow(n, mod - 2);
if(o == -1)
for(int i = 0; i < n; i ++) a[i] = mul(a[i], ninv);
}
// Coefficient-wise sum mod p; the result is as long as the longer operand.
poly operator + (const poly &A, const poly & B) {
    poly sum(max(sz(A), sz(B)), 0);
    for (int i = 0; i < sz(A); ++i) sum[i] = A[i];
    for (int i = 0; i < sz(B); ++i) sum[i] = add(sum[i], B[i]);
    return sum;
}
// Coefficient-wise difference A - B mod p; the result is as long as the longer operand.
poly operator - (const poly &A, const poly & B) {
    poly diff(max(sz(A), sz(B)), 0);
    for (int i = 0; i < sz(A); ++i) diff[i] = A[i];
    for (int i = 0; i < sz(B); ++i) diff[i] = sub(diff[i], B[i]);
    return diff;
}
// Zero the first n ints of buffer a (argument parenthesised so expressions are safe).
#define clr(a, n) (memset(a, 0, sizeof(int) * (n)))
// Scratch buffers for operator*: they are all-zero between calls (each call clears
// what it used). lim caps the length of every product; main() sets it before use.
int a[N << 1], b[N << 1], lim;
// Polynomial product A * B computed with NTT, truncated to at most lim coefficients.
// An exact product of non-positive length (both operands empty, or one empty and the
// other of length 1) yields an empty polynomial instead of resizing to a negative size.
poly operator * (const poly & A, const poly & B) {
    const int full = sz(A) + sz(B) - 1;   // length of the untruncated product
    if (full <= 0) return poly();
    for (int i = 0; i < sz(A); i ++) a[i] = A[i];
    for (int i = 0; i < sz(B); i ++) b[i] = B[i];
    poly C(min(lim, full));
    // Smallest power of two that holds the whole product, so the cyclic
    // convolution computed by the NTT has no wrap-around.
    int len = 1;
    while (len < full) len <<= 1;
    ntt(a, len, 1), ntt(b, len, 1);
    for (int i = 0; i < len; i ++) a[i] = mul(a[i], b[i]);
    ntt(a, len, -1);
    for (int i = 0; i < sz(C); i ++) C[i] = a[i];
    // Restore the all-zero invariant of the scratch buffers.
    clr(a, len), clr(b, len);
    return C;
}
// Scales every coefficient of A by the constant k (mod p).
poly operator * (const int & k, const poly & A) {
    poly scaled;
    scaled.reserve(A.size());
    for (const int coef : A) scaled.push_back(mul(coef, k));
    return scaled;
}
// Newton iteration for the power-series inverse: on return B holds 1/A mod x^n.
// Requires A[0] != 0 (mod p) and B empty on the outermost call.
// Each step uses B <- 2B - B^2 A, which doubles the number of correct terms.
void pINV(poly &A, poly &B, int n) {
    // Nothing to compute; without this guard n == 0 recursed forever ((0 + 1) / 2 == 0).
    if (n <= 0) return;
    if (n == 1) {
        B.push_back(qpow(A[0], mod - 2));
        return;
    }
    pINV(A, B, (n + 1) / 2);
    // Only A mod x^n matters at this level; copy just that prefix instead of all of A.
    poly C(A.begin(), A.begin() + min(n, sz(A)));
    C.resize(n);
    B = 2 * B - B * B * C;
    B.resize(n);
}
// Returns the first sz(A) coefficients of 1/A (empty for an empty A).
poly INV(poly A) {
    poly B;
    pINV(A, B, sz(A));
    return B;
}
// Tables filled by init(): inv[i] = 1/i, fac[i] = i!, ifac[i] = 1/i! (all mod p).
int inv[N], fac[N], ifac[N];
// Fills inv[1..n], fac[0..n] and ifac[0..n]; n must be smaller than N.
void init(int n) {
    inv[1] = 1;
    fac[0] = 1;
    for (int i = 1; i <= n; ++i) {
        fac[i] = mul(fac[i - 1], i);
        // Linear-time inverse recurrence: 1/i = -(mod / i) * 1/(mod % i), with mod % i < i.
        if (i >= 2) inv[i] = sub(0, mul(mod / i, inv[mod % i]));
    }
    ifac[n] = qpow(fac[n], mod - 2);
    for (int i = n; i > 0; --i) ifac[i - 1] = mul(ifac[i], i);
}
// Formal derivative: returns A' with exactly sz(A) - 1 coefficients
// (empty for a constant or empty A). The previous trailing pop_back() dropped
// the top coefficient of A' and was undefined behaviour when sz(A) <= 1.
poly qiudao(const poly &A) {
    poly B;
    if (sz(A) > 1) B.reserve(sz(A) - 1);
    for (int i = 1; i < sz(A); i ++) B.push_back(mul(i, A[i]));
    return B;
}
// Formal antiderivative with zero constant term, kept at sz(A) coefficients
// (the term A[sz(A) - 1] / sz(A) is dropped). Reads inv[1 .. sz(A) - 1] from init().
poly jifen(const poly A) {
    poly B(A.size(), 0);
    for (int i = sz(A) - 1; i >= 1; --i) B[i] = mul(A[i - 1], inv[i]);
    return B;
}
// Power-series logarithm: returns ln(A) mod x^{sz(A)} (exactly sz(A) coefficients,
// the constant term being 0). Requires A[0] == 1. Computed as the antiderivative
// of A' / A; products are still subject to the global lim truncation in operator*.
poly ln(const poly &A) {
    // Empty input would otherwise reach INV with n == 0.
    if (A.empty()) return poly();
    const int n = sz(A);
    // Full derivative A' (n - 1 coefficients), computed here so the top term is kept.
    poly d(n - 1);
    for (int i = 1; i < n; i ++) d[i - 1] = mul(i, A[i]);
    poly q = d * INV(A);
    // (ln A)' is only needed mod x^{n-1}; one spare slot keeps jifen's output at n terms.
    q.resize(n);
    return jifen(q);
}
// Newton iteration for the power-series exponential: on return B approximates
// exp(A) mod x^n (accuracy follows that of ln()). Requires A[0] == 0, sz(A) >= n,
// and B empty on the outermost call.
void pexp(poly &A, poly & B, int n) {
    // Nothing to compute; without this guard n == 0 recursed forever ((0 + 1) / 2 == 0).
    if (n <= 0) return;
    if (n == 1) {
        B.push_back(1);
        return;
    }
    pexp(A, B, (n + 1) / 2);
    poly lnB = B;
    lnB.resize(n);
    lnB = ln(lnB);
    // ln() may return more than n coefficients; only the first n are used, and
    // iterating over all of them previously read A past its end.
    lnB.resize(n);
    // Newton step: B <- B * (1 - ln B + A).
    for (int i = 0; i < n; i ++) lnB[i] = sub(A[i], lnB[i]);
    lnB[0] = add(lnB[0], 1);
    B = B * lnB;
    B.resize(n);
}
// Returns the first sz(A) coefficients of exp(A) (empty for an empty A).
poly exp(poly A) {
    poly C;
    pexp(A, C, sz(A));
    return C;
}
// Divide-and-conquer product of the linear factors (1 - a[i] x) for i in [l, r].
// The result has r - l + 2 coefficients, subject to the lim truncation in operator*.
// An empty range (l > r) yields the constant polynomial 1.
poly cdq(poly &a, int l, int r) {
    // Empty product; previously l > r recursed forever because mid reproduced (l, r).
    if (l > r) return poly(1, 1);
    if (l == r) return poly{1, sub(0, a[l])};
    const int mid = (l + r) >> 1;
    const poly left = cdq(a, l, mid);
    const poly right = cdq(a, mid + 1, r);
    return left * right;
}
// Number of generating-function terms to produce; read and incremented in main().
int t;
// Returns exactly t + 1 coefficients of sum_{k>=0} (sum_{i=1}^n a[i]^k) x^k / k!,
// reading a[1..n]. Uses sum_i 1/(1 - a_i x) = n - x * (ln prod_i (1 - a_i x))'.
poly get(poly &a, int n) {
    poly f = cdq(a, 1, n);
    f.resize(t + 5);   // only low-order terms are needed from here on
    f = ln(f);
    f = qiudao(f);     // f = (ln P)'
    // Fix the length at t + 1. Depending on lim and the helpers, the derivative could
    // come back with only t entries (e.g. n = m = t = 1), so the shift and the ifac
    // loop below accessed f[t] out of bounds. This also drops unused trailing terms.
    f.resize(t + 1);
    // Coefficients of n - x * f: shift up by one and negate, constant term n.
    for (int i = t - 1; i >= 0; i --)
        f[i + 1] = sub(0, f[i]);
    f[0] = n;
    // Ordinary -> exponential generating function.
    for (int i = 0; i <= t; i ++) f[i] = mul(f[i], ifac[i]);
    return f;
}
// f holds a[1..n] and g holds b[1..m] (index 0 unused); get() later replaces them
// with the truncated generating functions A(x) and B(x).
poly f, g;
int n, m;
int main() {
    scanf("%d%d", &n, &m);
    f.resize(n + 3);
    g.resize(m + 3);
    for (int i = 1; i <= n; ++i) scanf("%d", &f[i]);
    for (int i = 1; i <= m; ++i) scanf("%d", &g[i]);
    scanf("%d", &t);
    ++t;   // from here on, answers are printed for k = 1 .. t - 1
    // Products are truncated to lim coefficients; init() fills the tables up to lim.
    lim = 2 * max(n, max(m, t));
    init(lim);
    f = get(f, n);
    g = get(g, m);
    f = f * g;   // EGF whose k-th term is sum_{i,j} (a_i + b_j)^k / k!
    const int scale = mul(inv[n], inv[m]);   // 1 / (n * m): sum -> average
    for (int k = 1; k < t; ++k) {
        f[k] = mul(f[k], fac[k]);   // undo the 1/k! of the EGF
        printf("%d\n", mul(f[k], scale));
    }
    return 0;
}