Codeforces1648C. Tyler and Strings
阿新 • • 發佈:2022-03-24
題目大意
給定一個長為 \(n(1\leq n\leq 2\times10^5)\) 的整數序列 \(s(1\leq s_{i}\leq2\times10^5)\) ,以及一個長為 \(m(1\leq m\leq 2\times10^5)\) 的整數序列 \(t(1\leq t_{i}\leq2\times10^5)\) ,求將 \(s\) 重新排列後能得到多少種互不相同的序列(相同的數互換位置視為同一種),使其字典序小於 \(t\) ,答案對 \(998244353\) 取模。
思路
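如果字典序要小於 \(t\) ,那麼要與 \(t\) 有一個相同的字首,並且在字首之後的第一個位置上嚴格小於 \(t\) ;或者整個 \(s\) 恰好是 \(t\) 的真字首(僅當 \(n<m\) 時可能,此時答案額外加 \(1\) )。我們可以列舉這個相同字首的長度。當長度為 \(i-1\) 時,第 \(i\) 個數字可以是 \(s\) 中尚未使用且小於 \(t_i\) 的任意一個數,其後剩下的 \(n-i\) 個數可以任意排列。設 \(c_v\) 為數值 \(v\) 目前剩餘的個數,則這一步的貢獻為 \(\left(\sum_{v<t_i}c_v\right)\cdot\dfrac{(n-i)!}{\prod_v c_v!}\) 。其中 \(\sum_{v<t_i}c_v\) 用樹狀陣列維護;\(\prod_v \frac{1}{c_v!}\) 在每用掉一個 \(t_i\) 時乘上 \(c_{t_i}\) 即可更新。若 \(s\) 中已經沒有 \(t_i\) 可用,就無法再延長相同字首,停止列舉。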
程式碼
// Codeforces 1648C "Tyler and Strings".
// Counts the distinct rearrangements of S that are lexicographically smaller
// than T, modulo 998244353.
//
// NOTE: every preprocessor directive and every `//` comment must sit on its own
// line. A directive extends to the end of its line and a `//` comment hides the
// rest of the line, so collapsing this file onto one line makes it uncompilable.
#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 100000000;
const LL mod = 998244353;   // modulus used for the answer
const int maxn = 200010;    // array capacity; main() uses indices up to 200005

// N, M       : lengths of S and T (read in main).
// S, T       : the two sequences, stored 1-based.
// cnt[v]     : number of not-yet-consumed copies of value v from S.
// fact[k]    : k! mod `mod`;  invfact[k] : inverse of k! mod `mod`.
// dat        : Fenwick (binary indexed) tree over values 1..n holding cnt.
// n          : size of the Fenwick tree's value domain (set to 200000 in main).
LL N, M, S[maxn], T[maxn], cnt[maxn], fact[maxn], invfact[maxn], dat[maxn], n;

// Fenwick point update: adds x to position i (1-based, i >= 1).
void add(LL i, LL x) {
    while (i <= n) {
        dat[i] += x;
        i += i & (-i);   // jump to the next node whose range covers i
    }
}

// Fenwick prefix query: returns dat-sum over positions 1..i (0 when i == 0).
LL sum(LL i) {
    LL ans = 0;
    while (i) {
        ans += dat[i];
        i -= i & (-i);   // strip the lowest set bit
    }
    return ans;
}

// Binary exponentiation: returns a^x mod m for x >= 0.
// `a` is not reduced before the first multiplication, so the caller should pass
// a value small enough that ans * a fits in a signed 64-bit integer.
LL qpow(LL a, LL x, LL m) {
    LL ans = 1;
    while (x) {
        if (x & 1) ans = ans * a % m;
        a = a * a % m;
        x >>= 1;
    }
    return ans;
}

// Fills fact[0..n] and invfact[0..n] modulo m.
// invfact[n] is obtained as fact[n]^(m-2), i.e. via Fermat's little theorem,
// which is only a modular inverse when m is prime and n < m.
void fact_init(LL n, LL m) {
    fact[0] = fact[1] = 1;
    for (LL i = 2; i <= n; i++) fact[i] = fact[i - 1] * i % m;
    invfact[n] = qpow(fact[n], m - 2, m);
    // 1/(i-1)! = (1/i!) * i, so walk downwards from n.
    for (LL i = n; i > 0; i--) invfact[i - 1] = invfact[i] * i % m;
}

// Computes and prints the answer using the globals filled in by main().
// Consumes cnt[] and the Fenwick tree while it walks along T.
void solve() {
    LL ans = 0, exans = 1;
    // allInv = product over values v of 1 / cnt[v]!  (mod `mod`).
    LL allInv = 1;
    for (LL i = 1; i <= 200000; i++) {
        if (cnt[i]) allInv = allInv * invfact[cnt[i]] % mod;
    }
    LL mi = min(N, M);
    // exans counts the single arrangement equal to T[1..N]; it is smaller than
    // T only when it is a proper prefix, i.e. when N < M.
    if (N >= M) exans = 0;
    for (LL i = 1; i <= mi; i++) {
        // Positions 1..i-1 equal T; position i takes a remaining value < T[i].
        LL now = sum(T[i] - 1);
        // sum over v < T[i] of (N-i)! * cnt[v] / prod(cnt[u]!)
        ans = (ans + now * fact[N - i] % mod * allInv % mod) % mod;
        if (!cnt[T[i]]) {
            // T[i] is not available in S any more: the common prefix cannot
            // grow, and S cannot be a prefix of T either.
            exans = 0;
            break;
        }
        // Consume one copy of T[i]: 1/cnt! becomes 1/(cnt-1)! = (1/cnt!) * cnt.
        allInv = (allInv * cnt[T[i]]) % mod;
        cnt[T[i]]--;
        add(T[i], -1);
    }
    cout << (ans + exans) % mod << endl;
}

// Reads N, M, S and T from stdin, builds cnt[] and the Fenwick tree, then
// prints the answer.
int main() {
    IOS;
    n = 200000;
    fact_init(200005, mod);
    cin >> N >> M;
    for (int i = 1; i <= N; i++) {
        cin >> S[i];
        cnt[S[i]]++;
        add(S[i], 1);
    }
    for (int i = 1; i <= M; i++) cin >> T[i];
    solve();
    return 0;
}