HDU 6270 Marriage
阿新 • 發佈:2018-12-10
題解:常見的二項式反演,用啟發式合併(每次選最小的兩個集合作 NTT,類比下線段樹)加速 NTT 即可。複雜度 $O(n \log n \log n)$。
// HDU 6270 "Marriage".
//
// Approach: binomial inversion (inclusion-exclusion) over the number of
// forced same-family couples.
// - For a family with `man` men and `fem` women, the number of ways to fix
//   j couples inside that family is C(man, j) * C(fem, j) * j!.
// - Multiplying these per-family polynomials gives g[k], the number of ways
//   to fix k same-family couples overall.
// - The printed value is sum_k (-1)^k * g[k] * (S - k)!, where S is the total
//   number of men read from the input.
// The polynomials are multiplied with NTT, always merging the two shortest
// ones first (Huffman order) to bound the total transform work.
#include <algorithm>
#include <cstdio>
#include <queue>
#include <utility>
#include <vector>

typedef long long LL;

const int MX = 400007;      // scratch / table size; must exceed the largest NTT length used
const int mod = 998244353;  // NTT-friendly prime (119 * 2^23 + 1)
const int P = mod;          // same prime, name used by the NTT routines
const int G = 3;            // primitive root modulo P
const int NUM = 20;         // wn[] supports transform lengths up to 2^(NUM - 1)
const int MAXS = 100000;    // factorial table bound; NOTE(review): assumes totals per test <= 1e5 — confirm problem limits

LL wn[NUM];
LL va[MX], vb[MX];

// Modular product of two values already reduced into [0, P).
// P < 2^30, so the 64-bit product cannot overflow; no long-double trick is needed.
LL mul(LL x, LL y) { return x * y % P; }

// Returns a^x mod m by binary exponentiation (x >= 0). Negative bases are
// normalised into [0, m) first.
LL quick_mod(LL a, LL x, LL m) {
    LL ans = 1;
    a %= m;
    if (a < 0) a += m;
    while (x) {
        if (x & 1) ans = ans * a % m;
        x >>= 1;
        a = a * a % m;
    }
    return ans;
}

// Precomputes wn[i] = G^((P - 1) / 2^i), a primitive 2^i-th root of unity mod P.
// Must be called once before the first NTT.
void GetWn() {
    for (int i = 0; i < NUM; i++) wn[i] = quick_mod(G, (P - 1) >> i, P);
}

// Permutes F[0..len) into bit-reversed index order (len must be a power of two).
void Rader(LL F[], int len) {
    int j = len >> 1;  // bit-reversal of the current i
    for (int i = 1; i < len - 1; i++) {
        if (i < j) std::swap(F[i], F[j]);
        // Advance j to the bit-reversal of i + 1 (reversed-binary increment).
        int k = len >> 1;
        while (j >= k) {
            j -= k;
            k >>= 1;
        }
        if (j < k) j += k;
    }
}

// In-place iterative NTT of length len (power of two, at most 2^(NUM - 1)).
// t == 1: forward transform; t == -1: inverse transform, including the 1/len scaling.
// Inputs are expected in [0, P); outputs are in [0, P).
void NTT(LL F[], int len, int t) {
    Rader(F, len);
    int id = 0;
    for (int h = 2; h <= len; h <<= 1) {
        id++;  // wn[id] is a primitive h-th root of unity
        for (int j = 0; j < len; j += h) {
            LL E = 1;
            for (int k = j; k < j + h / 2; k++) {
                LL u = F[k];
                LL w = E * F[k + h / 2] % P;
                F[k] = (u + w) % P;
                F[k + h / 2] = (u - w + P) % P;
                E = E * wn[id] % P;
            }
        }
    }
    if (t == -1) {
        // The inverse transform equals the forward transform with outputs 1..len-1 reversed.
        std::reverse(F + 1, F + len);
        LL len_inv = quick_mod(len, P - 2, P);
        for (int i = 0; i < len; i++) F[i] = F[i] * len_inv % P;
    }
}

// Cyclic convolution of length len: a <- a * b (mod x^len - 1).
// b is overwritten with its forward transform.
void Conv(LL a[], LL b[], int len) {
    NTT(a, len, 1);
    NTT(b, len, 1);
    for (int i = 0; i < len; i++) a[i] = mul(a[i], b[i]);
    NTT(a, len, -1);
}

LL fac[MX], inv[MX];  // fac[k] = k! mod P; inv[k] = (k!)^{-1} mod P after init()

// Fills fac[0..MAXS] with factorials and inv[0..MAXS] with inverse factorials.
void init() {
    fac[0] = fac[1] = inv[0] = inv[1] = 1;
    for (int i = 2; i <= MAXS; i++) {
        fac[i] = fac[i - 1] * i % mod;
        inv[i] = (mod - mod / i) * inv[mod % i] % mod;  // linear-time modular inverse of i
    }
    // Prefix products turn the inverses of 1..i into the inverse of i!.
    for (int i = 2; i <= MAXS; i++) inv[i] = inv[i - 1] * inv[i] % mod;
}

// Binomial coefficient C(n, m) mod P, or 0 when m is outside [0, n].
// Requires n <= MAXS.
LL C(int n, int m) {
    if (m < 0 || n < m) return 0;
    return fac[n] * inv[m] % mod * inv[n - m] % mod;
}

std::vector<LL> v[MX];  // v[i]: coefficient list of polynomial i (v[0] is used only when n == 0)

// Heap entry: polynomial index and its current length. operator< is inverted
// so that std::priority_queue pops the shortest polynomial first.
struct node {
    int id, sz;
    node() {}
    node(int id, int sz) : id(id), sz(sz) {}
    bool operator<(const node& o) const { return sz > o.sz; }
};

// Replaces a with the product a * b and stores its length in c.sz.
// b's storage is released because the merged-away polynomial is never used again.
void work(std::vector<LL>& a, std::vector<LL>& b, node& c) {
    int mx = (int)a.size() + (int)b.size() - 1;  // number of product coefficients
    int len = 1;
    while (len < mx) len <<= 1;  // cyclic length >= mx makes the convolution linear
    std::fill(va, va + len, 0);
    std::fill(vb, vb + len, 0);
    std::copy(a.begin(), a.end(), va);
    std::copy(b.begin(), b.end(), vb);
    Conv(va, vb, len);
    a.assign(va, va + mx);
    std::vector<LL>().swap(b);
    c.sz = (int)a.size();
}

int main() {
#ifdef LOCAL
    freopen("input.txt", "r", stdin);
#endif  // LOCAL
    init();
    GetWn();
    int T;
    if (scanf("%d", &T) != 1) return 0;
    while (T--) {
        int n;
        if (scanf("%d", &n) != 1) return 0;
        int sum = 0;  // total number of men; NOTE(review): assumed equal to the total number of women
        std::priority_queue<node> q;
        for (int i = 1; i <= n; i++) {
            int man = 0, fem = 0;
            if (scanf("%d%d", &man, &fem) != 2) return 0;
            sum += man;
            v[i].clear();
            int lim = std::min(man, fem);
            for (int j = 0; j <= lim; j++) {
                // Ways to form j couples inside this family, reduced so NTT inputs stay in [0, P).
                v[i].push_back(C(man, j) * C(fem, j) % mod * fac[j] % mod);
            }
            q.push(node(i, (int)v[i].size()));
        }
        if (q.empty()) {
            // No families: the product is the constant polynomial 1 (avoids top() on an empty heap).
            v[0].assign(1, 1LL);
            q.push(node(0, 1));
        }
        // Huffman-style merging: always multiply the two shortest polynomials.
        while (q.size() >= 2) {
            node a = q.top();
            q.pop();
            node b = q.top();
            q.pop();
            node c(a.id, 0);
            work(v[a.id], v[b.id], c);
            q.push(c);
        }
        const std::vector<LL>& g = v[q.top().id];
        LL ans = 0;
        for (int k = 0; k < (int)g.size(); k++) {
            LL term = g[k] * fac[sum - k] % mod;
            // Inclusion-exclusion sign (-1)^k.
            ans = (k & 1) ? (ans - term + mod) % mod : (ans + term) % mod;
        }
        printf("%lld\n", ans);
    }
    return 0;
}