1. 程式人生 > 其它 >Codeforces1648C. Tyler and Strings

Codeforces1648C. Tyler and Strings

傳送門

題目大意

給定一個長為 \(n(1\leq n\leq 2\times10^5)\) 的整數序列 \(s(1\leq s_{i}\leq2\times10^5)\) ,以及一個長為 \(m(1\leq m\leq 2\times10^5)\) 的整數序列 \(t(1\leq t_{i}\leq2\times10^5)\) ,求有多少種 \(s\) 的排列,其字典序小於 \(t\) ,答案對 \(998244353\) 取模。

思路

如果字典序要小於 \(t\) ,那麼要與 \(t\) 有一個相同的字首,我們可以列舉這個相同字首的長度。當長度為 \(i-1\) 時,第 \(i\) 個數字可以是 \(s\)

中剩下的小於 \(t_{i}\) 的所有數字,然後再乘以後面部分的全排列,設 \(s\) 總共有 \(k\) 種不同的數字,每種數字 \(j\) 在相同字首長度為 \(i-1\) 時剩餘個數為 \(cnt_j\) ,對於第 \(i\) 為每個合法的數字種類 \(j\) ,其貢獻就為 \(\frac{(n-i)!}{cnt_1!cnt_2!...(cnt_j-1)!...cnt_k!}=\frac{(n-i)!}{cnt_1!cnt_2!...cnt_k!}cnt_j\) ,於是我們可以記所有在第 \(i\) 個數字處可以列舉的數字個數和為 \(now\) ,於是長度為 \(i-1\) 的相同長度的總貢獻為 \(\frac{(n-i)!}{cnt_1!cnt_2!...cnt_k!}now\)
\(now\) 我們可以用樹狀陣列維護 \(cnt\) 來求得,一開使可以先預處理出分母,之後的都可以推出。此外還需要特判 \(s\) 整體已經是 \(t\) 字首的這一個可能答案,最後將貢獻相加即可。

程式碼

#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 100000000;
const LL mod = 998244353;
const int maxn = 200010;

LL N, M, S[maxn], T[maxn], cnt[maxn], fact[maxn], invfact[maxn], dat[maxn], n;

void add(LL i, LL x)
{
	while (i <= n)
	{
		dat[i] += x;
		i += i & (-i);
	}
}

LL sum(LL i)
{
	LL ans = 0;
	while (i)
	{
		ans += dat[i];
		i -= i & (-i);
	}

	return ans;
}

LL qpow(LL a, LL x, LL m)
{
	LL ans = 1;
	while (x)
	{
		if (x & 1)
			ans = ans * a % m;
		a = a * a % m;
		x >>= 1;
	}

	return ans;
}

void fact_init(LL n, LL m)
{
	fact[0] = fact[1] = 1;
	for (LL i = 2; i <= n; i++)
		fact[i] = fact[i - 1] * i % m;
	invfact[n] = qpow(fact[n], m - 2, m);;
	for (LL i = n; i > 0; i--)
		invfact[i - 1] = invfact[i] * i % m;
}

void solve()
{
	LL ans = 0, exans = 1;
	LL allInv = 1;
	for (LL i = 1; i <= 200000; i++)
	{
		if (cnt[i])
			allInv = allInv * invfact[cnt[i]] % mod;
	}
	LL mi = min(N, M);
	if (N >= M)
		exans = 0;
	for (LL i = 1; i <= mi; i++)
	{
		LL now = sum(T[i] - 1);
		ans = (ans + now * fact[N - i] % mod * allInv % mod) % mod;
		if (!cnt[T[i]])
		{
			exans = 0;
			break;
		}
		allInv = (allInv * cnt[T[i]]) % mod;
		cnt[T[i]]--;
		add(T[i], -1);
	}
	cout << (ans + exans) % mod << endl;
}

int main()
{
	IOS;
	n = 200000;
	fact_init(200005, mod);
	cin >> N >> M;
	for (int i = 1; i <= N; i++)
	{
		cin >> S[i];
		cnt[S[i]]++;
		add(S[i], 1);
	}
	for (int i = 1; i <= M; i++)
		cin >> T[i];
	solve();

	return 0;
}