21牛客9C - Cells(LGV引理)
阿新 • • 發佈:2021-08-15
題目
題解
先放一個LGV引理的連結在這裡。
主要講講題解中這裡的推導
每次從後往前,後一列減去\(t\)倍前一列可得
\[(a_i+1)\prod_{k=2}^{j+1}{(a_i+k)}-t(a_i+1)\prod_{k=2}^{j}{(a_i+k)}=(a_i+1)(a_i+j+1-t)\prod_{k=2}^j{(a_i+k)} \]令\(t=j\),可得
\[(a_i+1)^2\prod_{k=2}^j{(a_i+k)} \]重複這個過程,直到連乘的項消去,剩餘第\(j\)列就是\((a_i+1)^j\)。
這個是範德蒙矩陣的形式,具體化簡過程百度或自己推。
最後要求\(\prod_{1\le i < j \le n}{(a_j-a_i)}\),直接卷積求出所有差值的個數然後快速冪即可。由於相同差值至多出現1e5次,可以直接ntt。注意多項式乘法最後度數要乘2,這樣卷積出來的結果才是正確的。
#include <bits/stdc++.h> #define endl '\n' #define IOS std::ios::sync_with_stdio(0); cin.tie(0); cout.tie(0) #define mp make_pair #define seteps(N) fixed << setprecision(N) typedef long long ll; using namespace std; /*-----------------------------------------------------------------*/ ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;} #define INF 0x3f3f3f3f const int N = 3e6; const int M = 998244353; const double eps = 1e-5; int rev[N]; inline ll qpow(ll a, ll b, ll m) { ll res = 1; while(b) { if(b & 1) res = (res * a) % m; a = (a * a) % m; b = b >> 1; } return res; } void change(ll y[], int len) { for(int i = 0; i < len; ++i) { rev[i] = rev[i >> 1] >> 1; if(i & 1) { rev[i] |= len >> 1; } } for(int i = 0; i < len; ++i) { if(i < rev[i]) { swap(y[i], y[rev[i]]); } } return; } void fft(ll y[], int len, int on) { change(y, len); for(int h = 2; h <= len; h <<= 1) { ll gn = qpow(3, (M - 1) / h, M); if(on == -1) gn = qpow(gn, M - 2, M); for(int j = 0; j < len; j += h) { ll g = 1; for(int k = j; k < j + h / 2; k++) { ll u = y[k]; ll t = g * y[k + h / 2] % M; y[k] = (u + t) % M; y[k + h / 2] = (u - t + M) % M; g = g * gn % M; } } } if(on == -1) { ll inv = qpow(len, M - 2, M); for(int i = 0; i < len; i++) { y[i] = y[i] * inv % M; } } } int get(int x) { int res = 1; while(res < x) { res <<= 1; } return res; } ll f1[N], f2[N]; int main() { IOS; int n; cin >> n; int mx = 1000000; int up = 0; ll rj = 1, prod = 1, tj = 1; for(int i = 1; i <= n; i++) { int x; cin >> x; f1[x] = 1; f2[mx - x] = 1; prod = prod * (x + 1) % M; tj = tj * i % M; rj = rj * tj % M; up = x; } rj = qpow(rj, M - 2, M); int len = get(2 * mx + 1); fft(f1, len, 1); fft(f2, len, 1); for(int i = 0; i < len; i++) f1[i] = f1[i] * f2[i] % M; fft(f1, len, -1); ll ans = 1; for(int i = 1; i <= mx; i++) { ans = ans * qpow(i, f1[mx - i], M) % M; } ans = ans * rj % M * prod % M; cout << ans << endl; }