【題解】Luogu-P3301 [SDOI2013]方程
阿新 • • 發佈:2021-12-17
Description
給定方程及不等式組
\[\begin{cases} x_1+x_2+\cdots+x_n=m\\ \\ x_1\le a_1\\ x_2\le a_2\\ \cdots\\ x_{n1}\le a_{n1}\\ \\ x_{n1+1}\ge a_{n1+1}\\ x_{n1+2}\ge a_{n1+2}\\ \cdots\\ x_{n1+n2}\ge a_{n1+n2} \end{cases} \]請求出該方程組的正整數解的個數 \(\bmod p\)。
- 對於 \(100\%\) 的資料:\(n\le 10^9,n1\le 8,n2\le 8,m\le 10^9,p\le 437367875,T\le 5,1\le a_{1\dots n1+n2}\le m,n1+n2\le n\)
Solution
前置芝士:
-
基礎的計數 + 組合知識
-
exLucas
對於形如 \(x_i\ge a_i\) 的,用小奧思路將 \(m\gets m-(a_i-1)\),這時限制就變成了 \(x_i\ge 1\),也就是去掉了限制。
對於形如 \(x\le a_i\) 的,反面考慮 \(x>a_i\),即 \(x\ge a_i+1\),其它無限制的情況數,然後就和上面一樣了。
注意一下容斥。
假設當前為 \(nowm\),那麼根據插板法,情況數就為 \(C_{nowm-1}^{n-1}\) ,這裡直接用 exLucas
即可。
時間複雜度為 \(O(n1!\cdot p\log m)\)
但是你需要堅信它是跑不滿的(
然後 \(70\) 了。
億些小優化:
- 提前分解 \(p\)
Code
//18 = 9 + 9 = 18. #include <iostream> #include <cstdio> #define Debug(x) cout << #x << "=" << x << endl #define int long long using namespace std; int qpow(int a, int b, int p) { int base = a, ans = 1; while (b) { if (b & 1) { ans = ans * base % p; } base = base * base % p; b >>= 1; } return ans; } int fac[10]; int cal(int n, int p, int pos, int pa) { if (!n) { return 1; } int ans = qpow(fac[pos], n / pa, pa); for (int i = 1; i <= n % pa; i++) { if (i % p) { ans = ans * i % pa; } } return ans * cal(n / p, p, pos, pa) % pa; } int cnt_p(int n, int m, int p) { int cnt = 0; for (int i = p; i <= n; i *= p) { cnt += n / i; } for (int i = p; i <= m; i *= p) { cnt -= m / i; } for (int i = p; i <= n - m; i *= p) { cnt -= (n - m) / i; } return cnt; } int x, y; void exgcd(int a, int b) { if (!b) { x = 1, y = 0; return; } exgcd(b, a % b); int tmp = x; x = y; y = tmp - a / b * y; } int inv(int a, int p) { exgcd(a, p); x = (x % p + p) % p; return x; } int C(int n, int m, int p, int pos, int pa) { int a = cal(n, p, pos, pa), b = cal(m, p, pos, pa), c = cal(n - m, p, pos, pa), cnt = cnt_p(n, m, p); return a * inv(b, pa) % pa * inv(c, pa) % pa * qpow(p, cnt, pa) % pa; } int prime[10], a[10], b[10]; int CRT(int n) { int m = 1; for (int i = 1; i <= n; i++) { m *= a[i]; } int ans = 0; for (int i = 1; i <= n; i++) { int mi = m / a[i]; int Mi = inv(mi, a[i]); ans = (ans + b[i] * mi % m * Mi % m) % m; } return ans; } int k; void pre(int p) { for (int i = 2; i * i <= p; i++) { if (p % i == 0) { prime[++k] = i; a[k] = 1; while (p % i == 0) { a[k] *= i; p /= i; } } } if (p > 1) { prime[++k] = p; a[k] = p; } for (int i = 1; i <= k; i++) { fac[i] = 1; for (int j = 1; j <= a[i]; j++) { if (j % prime[i]) { fac[i] = fac[i] * j % a[i]; } } } } int exLucas(int n, int m) { if (n < m) { return 0; } for (int i = 1; i <= k; i++) { b[i] = C(n, m, prime[i], i, a[i]); } return CRT(k); } int p, n, n1, ans; int w[20]; void dfs(int tot, int bound, int nega, int nowm) { // Debug(nowm), Debug(nega); // Debug(exLucas(nowm - 1, n - 1, p)); ans = (ans + nega * exLucas(nowm - 1, n - 1) + p) % p; if (tot > n1) { return; } for (int i = bound; i <= n1; i++) { dfs(tot + 1, i + 1, -nega, nowm - w[i]); } } signed main() { int t; scanf("%lld%lld", &t, &p); pre(p); while (t--) { int n2, m; scanf("%lld%lld%lld%lld", &n, &n1, &n2, &m); for (int i = 1; i <= n1 + n2; i++) { scanf("%lld", w + i); } for (int i = 1; i <= n2; i++) { m -= (w[n1 + i] - 1); } ans = 0; dfs(1, 1, 1, m); printf("%lld\n", ans); } return 0; }