1. 程式人生 > 實用技巧 >[組合數學][CF1188E]Problem from Red Panda

[組合數學][CF1188E]Problem from Red Panda

題目連結

題目傳送門

簡要題意

有一個長度為 \(k\) 的陣列 \(a\),每次可以選擇一個 \(1\le i\le k\),讓 \(a_i\) 加上 \(k-1\),並對於所有的 \(j\ne i\)\(a_j\) 減掉 \(1\),任何時候必須保證 \(a\) 陣列非負。

求通過任意多次(可以為 \(0\) 次)操作,能達到的不同的 \(a\) 陣列方案數膜 \(998,244,353\) 後的結果。

資料範圍:\(2\le k\le 10^5\)\(a_i\ge 0\)\(\sum_{i=1}^ka_i\le10^6\),但實際上存在複雜度與 \(\sum_{i=1}^ka_i\) 無關的做法。

Solution

\(x_i\) 表示 \(a_i\) 被選中的次數,考慮一個 \(\{x_i\}\) 合法的條件。記 \(s=\sum_{i=1}^kx_i\)

我們不妨把操作看作陣列 \(a\) 整體減 \(1\) 之後 \(a_i+=k\)

顯然我們必須保證最後的 \(a\) 陣列非負,故 \(a_i-s+kx_i\ge 0\),也就是 \(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\)

在這個條件下,判斷是否對於 \(0\le t<s\) 滿足 \(t\) 輪操作之後 \(a\) 陣列非負,只需將所有 \(a_i\) 減掉 \(t\),然後嘗試用 \(t\)

\(k\) 來填充為負的 \(a_i\) 值,判斷是否能夠填充成功即可,即 \(\sum_{i=1}^k\lceil\frac{\max(t-a_i,0)}k\rceil\le t\)。顯然在 \(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\) 的限制下,第 \(i\) 個數被操作的次數不會超過 \(x_i\)

於是一個 \(\{x_i\}\) 合法的條件為:

(1)\(x_i\ge\lceil\frac{\max(s-a_i,0)}k\rceil\)

(2)對於所有 \(0\le t\le s\) 都有 \(\sum_{i=1}^k\lceil\frac{\max(t-a_i,0)}k\rceil\le t\)

在從小到大列舉 \(s\) 的過程中,\(\sum_{i=1}^k\lceil\frac{\max(s-a_i,0)}k\rceil\) 容易求出,將 \(a\) 排序後用指標維護 \(a_i<s\) 的部分,對 \(a_i\bmod k\) 用個桶維護每種值的出現次數即可。

回到問題,我們不能直接對 \(\{x_i\}\) 計數,因為不同的 \(x\) 陣列可能對應同一個 \(a\) 陣列。

首先我們發現,如果所有的 \(x_i\) 都相等,則這樣的操作對原陣列沒有影響。

也就是說,如果所有的 \(x_i\) 都不為 \(0\),則把所有 \(x_i\) 都減掉 \(1\) 之後會得到一個等價的方案。

同樣地如果將一部分 \(x_i\)(個數在 \([1,k-1]\) 之間)減掉 \(1\),則得到的方案一定不等價。

故可以轉化成對 \(\{x_i\}\) 陣列計數,但 \(x\) 陣列必須滿足至少有一個 \(0\)。易得這時有 \(0\le s\le\max a_i\)

從小到大列舉 \(s\),遇到 \(w=\sum_{i=1}^k\lceil\frac{\max(s-a_i,0)}k\rceil>s\) 的情況立刻 break 掉。

問題轉化成 \(k\) 個變數,其中前 \(r\)\(a_i<s\)\(i\) 個數)個變數有一個取值下界 \(down_i\),滿足 \(down_i\ge 1\)\(w=\sum_{i=1}^rdown_i\),求為這 \(k\) 個變數取值,使得至少有一個 \(0\),並且所有變數的和為 \(s\) 的方案數。

先去掉下界 \(down\),轉成所有變數的和為 \(s-w\),並且後 \(k-r\) 個變數至少有一個 \(0\)

考慮容斥,用任意方案減掉沒有 \(0\) 的方案,由插板法得方案數:

\[\binom{s-w+k-1}{k-1}-\binom{s-w+r-1}{k-1} \]

總複雜度 \(O(k\log k+\max a_i)\)

Code

#include <bits/stdc++.h>

template <class T>
inline void read(T &res)
{
	res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	if (bo) res = ~res + 1;
}

const int N = 2e6 + 5, djq = 998244353;

int k, n, a[N], cnt[N], fac[N], inv[N], ans;

int C(int n, int m) {return 1ll * fac[n] * inv[m] % djq * inv[n - m] % djq;}

int main()
{
	fac[0] = inv[0] = inv[1] = 1;
	for (int i = 1; i < N; i++) fac[i] = 1ll * fac[i - 1] * i % djq;
	for (int i = 2; i < N; i++) inv[i] = 1ll * (djq - djq / i) * inv[djq % i] % djq;
	for (int i = 2; i < N; i++) inv[i] = 1ll * inv[i] * inv[i - 1] % djq;
	read(k); int cur = 0;
	for (int i = 1; i <= k; i++) read(a[i]), n += a[i];
	std::sort(a + 1, a + k + 1);
	for (int i = 0, j = 1; i <= a[k]; i++)
	{
		while (a[j] < i) cnt[a[j++] % k]++;
		cur += cnt[(i - 1 + k) % k];
		if (cur > i) return std::cout << ans << std::endl, 0;
		ans = (ans + C(i - cur + k - 1, k - 1)) % djq;
		if (i - cur + j - 2 >= k - 1)
			ans = (ans - C(i - cur + j - 2, k - 1) + djq) % djq;
	}
	return std::cout << ans << std::endl, 0;
}