1. 程式人生 > 實用技巧 >CF1444B Divide and Sum

CF1444B Divide and Sum

題目來源:Codeforces Round #680 (Div. 1, based on Moscow Team Olympiad)/Codeforces Round #680 (Div. 2, based on Moscow Team Olympiad),CF1444B/CF1445D,Divide and Sum

說明:這個做法用到了 NTT,複雜度是 \(O(n\log n)\) 的。複雜度並不是最優,而且做法比較複雜。但是我個人挺喜歡這個做法的,感覺一步一步推下來,很有邏輯,也容易在考場上想到。雖然不是很好寫...QwQ

題目大意

題目連結

給定一個長度為 \(2n\) 的序列 \(a\)。現在要將 \(a\)

劃分為兩個子序列 \(p,q\),兩個子序列長度都恰好為 \(n\),且無公共元素。

劃分完成後,我們將 \(p,q\) 分別按從小到大從大到小排序,記得到的兩個序列分別為 \(x,y\)

我們規定一種劃分方式的權值為:\(f(p,q) = \sum_{i = 1}^{n}|x_i - y_i|\)

請求出所有劃分方式的權值之和。答案對 \(998244353\) 取模。

資料範圍:\(1\leq n\leq 150000\)\(1\leq a_i\leq 10^9\)

本題題解

首先,因為得到的兩個子序列是要排序的,所以元素的初始順序其實不重要。那麼可以先將 \(a\) 序列排序。以下討論 \(a\)

序列都是指排好序後的序列。

考慮 \(p,q\) 的第 \(i\) 位對答案的貢獻 (\(1\leq i\leq n\))。列舉 \(p\)\(i\) 位上的元素,記為 \(a_j\);列舉 \(q\)\(i\) 位上的元素,記為 \(a_k\) (\(k\neq j\))。那麼 \(a_j\) 必須恰好是 \(p\) 序列裡第 \(i\) 大的,\(a_k\) 必須恰好是 \(q\) 序列裡第 \(n-i+1\) 大的,也就是說,它們前面分別恰有 \(i-1\) 個 / \(n-i\) 個自己序列的元素。那麼不難用組合數求出,\(|a_j-a_k|\) 在第 \(i\) 位上對答案貢獻的方案數。具體來說:

  • \(k<j\) 時,對答案的貢獻是:\({k-1\choose n-i}{j-k-1\choose(i - 1) - (k - 1 - (n - i))}{2n-j\choose n-i}(a_j-a_k)\),化簡一下,等於:\({k-1\choose n-i}{j-k-1\choose n - k}{2n-j\choose n-i}(a_j-a_k)\)。這三個組合數,含義分別是:在 \(k\) 前面選出 \(q\) 序列裡的元素(剩下的都在 \(p\) 序列裡);在 \(k,j\) 之間選出 \(p\) 序列裡的元素;在 \(j\) 後面選出 \(p\) 序列裡的元素(剩下的都在 \(q\) 序列裡)。
  • \(j<k\) 時,對答案的貢獻是:\({j - 1\choose i - 1}{k - j - 1\choose (n - i) - (j - 1 - (i - 1))}{2n-k\choose i-1}(a_k-a_j)\),化簡一下,等於:\({j-1\choose i-1}{k-j-1\choose n - j}{2n-k\choose i-1}(a_k-a_j)\)。這三個組合數,和前面類似,含義分別是:在 \(j\) 前面選出 \(p\) 序列裡的元素;在 \(j,k\) 之間選出 \(q\) 序列裡的元素;在 \(k\) 後面選出 \(q\) 序列裡的元素。

暴力列舉 \(i,j,k\),按此式子計算答案,時間複雜度 \(O(n^3)\)。這個暴力做法的程式碼片段附在了參考程式碼部分。


繼續優化。發現列舉 \(j\) 再列舉 \(k\) 這件事比較愚蠢。考慮將它們拆開來,也就是分別列舉 \(j,k\)。以列舉 \(j\) 為例。考慮一個 \(j\) 的貢獻,有兩種情況:

  1. \(k<j\),此時這個 \(j\) 對答案的貢獻是 \(a_j\) 乘以一個係數。
  2. \(j<k\),此時這個 \(j\) 對答案的貢獻是 \(-a_j\) 乘以一個係數。

我們要求出這個係數。對於情況 1,相當於要求 \(j\) 前面存在一個合法的 \(k\)。考慮 \(j\) 前面要有哪些東西:

  • \(p\) 序列前 \(i-1\) 小的元素(恰好這麼多,否則 \(j\) 就不是第 \(i\) 了)。
  • \(q\) 序列前 \(n-i+1\) 小的元素(或者更多元素)。

因此,發現 \(j\) 一定要大於等於 \((i-1)+(n-i+1)+1=n+1\)。依次列舉 \(j\in[n+1,2n]\),每個 \(j\) 對答案的貢獻就是:\({j - 1\choose i - 1}\cdot {2n-j\choose n-i}\cdot a_j\)

同理,情況 2 中,\(j\in[1,n]\),對答案的貢獻是:\({j - 1\choose i - 1}\cdot {2n-j\choose n-i}\cdot (-a_j)\)

\(k\) 的貢獻也是類似的。分別是:\(k\in[n+1,2n]\)\({k-1\choose n-i}\cdot{2n-k\choose i-1}\cdot a_k\)\(k\in[1,n]\)\({k-1\choose n-i}\cdot {2n-k\choose i-1}\cdot (-a_k)\)

這樣,我們只需要先列舉 \(i\),再分別列舉 \(j,k\)(而不是套起來)。時間複雜度 \(O(n^2)\)。這個做法的程式碼片段附在了參考程式碼部分。


最後一步,我們把上述 \(n^2\) 的式子拆開,寫成卷積的形式,就可以了。

\(j\in[n+1,2n]\)\({j - 1\choose i - 1}\cdot {2n-j\choose n-i}\cdot a_j\) 為例,可以寫成:

\[\sum_{j=n+1}^{2n}a_j\cdot (j-1)!\cdot (2n-j)!\sum_{i=1}^{n}\frac{1}{(i-1)!(n-i)!}\cdot \frac{1}{(j-i)!(n-(j-i))!} \]

\(f_i=\frac{1}{(i-1)!(n-i)!}\)\(g_i=\frac{1}{i!(n-i)!}\),則後半部分就是 \(f\cdot g\) (多項式乘法)的第 \(j\) 項。我們對 \(f,g\) 做 NTT 即可。

我們求 \(j,k\) 的貢獻時,是兩個不同的式子,所以各要做一次 NTT。總時間複雜度 \(O(n\log n)\)

參考程式碼

內含一個精細優化後的 NTT 模板(namespace SuperNTT),因為太長了,我將其單獨取出並附在後面:

實際提交時,建議使用快速輸入、輸出,詳見本部落格公告。

// problem: CF1444B
#include <bits/stdc++.h>
using namespace std;

#define pb push_back
#define mk make_pair
#define lob lower_bound
#define upb upper_bound
#define fi first
#define se second
#define SZ(x) ((int)(x).size())

typedef unsigned int uint;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int, int> pii;

template<typename T> inline void ckmax(T& x, T y) { x = (y > x ? y : x); }
template<typename T> inline void ckmin(T& x, T y) { x = (y < x ? y : x); }

namespace SuperNTT {
// ...
} // namespace SuperNTT

const int MAXN = 3e5;
const int MOD = 998244353;
inline int mod1(int x) { return x < MOD ? x : x - MOD; }
inline int mod2(int x) { return x < 0 ? x + MOD : x; }
inline void add(int &x, int y) { x = mod1(x + y); }
inline void sub(int &x, int y) { x = mod2(x - y); }
inline int pow_mod(int x, int i) {
	int y = 1;
	while (i) {
		if (i & 1) y = (ll)y * x % MOD;
		x = (ll)x * x % MOD;
		i >>= 1;
	}
	return y;
}

int fac[MAXN + 5], ifac[MAXN + 5];
inline int comb(int n, int k) {
	if (n < k) return 0;
	return (ll)fac[n] * ifac[k] % MOD * ifac[n - k] % MOD;
}
void facinit(int lim = MAXN) {
	fac[0] = 1;
	for (int i = 1; i <= lim; ++i) fac[i] = (ll)fac[i - 1] * i % MOD;
	ifac[lim] = pow_mod(fac[lim], MOD - 2);
	for (int i = lim - 1; i >= 0; --i) ifac[i] = (ll)ifac[i + 1] * (i + 1) % MOD;
}

int n, a[MAXN + 5];

int f[MAXN * 4 + 5], g[MAXN * 4 + 5], res[MAXN * 4 + 5];
int main() {
	facinit();
	cin >> n;
	for (int i = 1; i <= n * 2; ++i) {
		cin >> a[i];
	}
	sort(a + 1, a + n * 2 + 1);
	int ans = 0;
	
	for (int i = 0; i <= n; ++i) {
		if (i != 0) f[i] = (ll)ifac[i - 1] * ifac[n - i] % MOD;
		g[i] = (ll)ifac[i] * ifac[n - i] % MOD;
	}
	
	SuperNTT :: work(f, g, n + 1, n + 1, res);
	for (int j = n + 1; j <= n * 2; ++j) {
		add(ans, (ll)a[j] * fac[j - 1] % MOD * fac[2 * n - j] % MOD * res[j] % MOD);
	}
	for (int j = 1; j <= n; ++j) {
		sub(ans, (ll)a[j] * fac[j - 1] % MOD * fac[2 * n - j] % MOD * res[j] % MOD);
	}
	
	memset(f, 0, sizeof(f));
	memset(g, 0, sizeof(g));
	for (int i = 1; i <= n; ++i) {
		f[i] = (ll)ifac[i - 1] * ifac[n - i] % MOD;
	}
	for (int i = n + 1; i <= n * 2 + 1; ++i) {
		g[i] = (ll)ifac[i - n - 1] * ifac[2 * n + 1 - i] % MOD;
	}
	reverse(f, f + n + 1);
	SuperNTT :: work(f, g, n + 1, n * 2 + 2, res);
	for (int j = n + 1; j <= n * 2; ++j) {
		add(ans, (ll)a[j] * fac[j - 1] % MOD * fac[n * 2 - j] % MOD * res[n + j] % MOD);
	}
	for (int j = 1; j <= n; ++j) {
		sub(ans, (ll)a[j] * fac[j - 1] % MOD * fac[n * 2 - j] % MOD * res[n + j] % MOD);
	}
	cout << ans << endl;
	return 0;
}

NTT 模板(namespace SuperNTT):

typedef unsigned int uint;
typedef long long unsigned int uint64;

constexpr uint Max_size = 1 << 21 | 5;
constexpr uint g = 3, Mod = 998244353;

inline uint norm_2(const uint x)
{
	return x < Mod * 2 ? x : x - Mod * 2;
}

inline uint norm(const uint x)
{
	return x < Mod ? x : x - Mod;
}

struct Z
{
	uint v;
	Z() { }
	Z(const uint _v) : v(_v) { }
};

inline Z operator+(const Z &x1, const Z &x2) { return x1.v + x2.v < Mod ? x1.v + x2.v : x1.v + x2.v - Mod; }
inline Z operator-(const Z &x1, const Z &x2) { return x1.v >= x2.v ? x1.v - x2.v : x1.v + Mod - x2.v; }
inline Z operator*(const Z &x1, const Z &x2) { return static_cast<uint64>(x1.v) * x2.v % Mod; }
inline Z &operator*=(Z &x1, const Z &x2) { x1.v = static_cast<uint64>(x1.v) * x2.v % Mod; return x1; }

Z Power(Z Base, int Exp)
{
	Z res = 1;
	for (; Exp; Base *= Base, Exp >>= 1)
		if (Exp & 1)
			res *= Base;
	return res;
}

inline uint mf(uint x)
{
	return (static_cast<uint64>(x) << 32) / Mod;
}

int size;
uint w[Max_size], w_[Max_size];

inline uint mult_Shoup_2(const uint x, const uint y, const uint y_)
{
	uint q = static_cast<uint64>(x) * y_ >> 32;
	return x * y - q * Mod;
}

inline uint mult_Shoup(const uint x, const uint y, const uint y_)
{
	return norm(mult_Shoup_2(x, y, y_));
}

inline void init(int n)
{
	for (size = 2; size < n; size <<= 1)
		;
	Z pr = Power(g, (Mod - 1) / size);
	size >>= 1;
	w[size] = 1, w_[size] = (static_cast<uint64>(w[size]) << 32) / Mod;
	if (size <= 8)
	{
		for (int i = 1; i < size; ++i)
			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
	}
	else
	{
		for (int i = 1; i < 8; ++i)
			w[size + i] = (w[size + i - 1] * pr).v, w_[size + i] = (static_cast<uint64>(w[size + i]) << 32) / Mod;
		pr *= pr, pr *= pr, pr *= pr;
		for (int i = 8; i < size; i += 8)
		{ 
			w[size + i + 0] = (w[size + i - 8] * pr).v, w_[size + i + 0] = (static_cast<uint64>(w[size + i + 0]) << 32) / Mod;
			w[size + i + 1] = (w[size + i - 7] * pr).v, w_[size + i + 1] = (static_cast<uint64>(w[size + i + 1]) << 32) / Mod;
			w[size + i + 2] = (w[size + i - 6] * pr).v, w_[size + i + 2] = (static_cast<uint64>(w[size + i + 2]) << 32) / Mod;
			w[size + i + 3] = (w[size + i - 5] * pr).v, w_[size + i + 3] = (static_cast<uint64>(w[size + i + 3]) << 32) / Mod;
			w[size + i + 4] = (w[size + i - 4] * pr).v, w_[size + i + 4] = (static_cast<uint64>(w[size + i + 4]) << 32) / Mod;
			w[size + i + 5] = (w[size + i - 3] * pr).v, w_[size + i + 5] = (static_cast<uint64>(w[size + i + 5]) << 32) / Mod;
			w[size + i + 6] = (w[size + i - 2] * pr).v, w_[size + i + 6] = (static_cast<uint64>(w[size + i + 6]) << 32) / Mod;
			w[size + i + 7] = (w[size + i - 1] * pr).v, w_[size + i + 7] = (static_cast<uint64>(w[size + i + 7]) << 32) / Mod;
		} 
	}
	for (int i = size - 1; i; --i)
		w[i] = w[i * 2], w_[i] = w_[i * 2];
	size <<= 1;
}

inline void DFT_fr_2(Z _A[], const int L)
{
	if (L == 1)
		return;
	uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
	do\
	{\
		uint _a = a, _b = b;\
		uint x = norm_2(_a + _b), y = norm_2(_a + Mod * 2 - _b);\
		a = x, b = y;\
	} while (0)
	if (L == 2)
	{
		butterfly1(A[0], A[1]);
		return;
	}
#define butterfly(a, b, _w, _w_)\
	do\
	{\
		uint _a = a, _b = b;\
		uint x = norm_2(_a + _b), y = mult_Shoup_2(_a + Mod * 2 - _b, _w, _w_);\
		a = x, b = y;\
	} while (0)
	if (L == 4)
	{
		butterfly1(A[0], A[2]);
		butterfly(A[1], A[3], w[3], w_[3]);
		butterfly1(A[0], A[1]);
		butterfly1(A[2], A[3]);
		return; 
	}
	for (int d = L >> 1; d != 4; d >>= 1)
		for (int i = 0; i != L; i += d << 1)
			for (int j = 0; j != d; j += 4)
			{
				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
			}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 4]);
		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
	}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 2]);
		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
		butterfly1(A[i + 4], A[i + 6]);
		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
	}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 1]);
		butterfly1(A[i + 2], A[i + 3]);
		butterfly1(A[i + 4], A[i + 5]);
		butterfly1(A[i + 6], A[i + 7]);
	}
#undef butterfly1
#undef butterfly
}

inline void IDFT_fr(Z _A[], const int L)
{
	if (L == 1)
		return;
	uint *A = reinterpret_cast<uint *>(_A);
#define butterfly1(a, b)\
	do\
	{\
		uint _a = a, _b = b;\
		uint x = norm_2(_a), t = norm_2(_b);\
		a = x + t, b = x + Mod * 2 - t;\
	} while (0)
	if (L == 2)
	{
		butterfly1(A[0], A[1]);
		A[0] = norm(norm_2(A[0])), A[0] = A[0] & 1 ? A[0] + Mod : A[0], A[0] /= 2;
		A[1] = norm(norm_2(A[1])), A[1] = A[1] & 1 ? A[1] + Mod : A[1], A[1] /= 2;
		return;
	}
#define butterfly(a, b, _w, _w_)\
	do\
	{\
		uint _a = a, _b = b;\
		uint x = norm_2(_a), t = mult_Shoup_2(_b, _w, _w_);\
		a = x + t, b = x + Mod * 2 - t;\
	} while (0)
	if (L == 4)
	{
		butterfly1(A[0], A[1]);
		butterfly1(A[2], A[3]);
		butterfly1(A[0], A[2]);
		butterfly(A[1], A[3], w[3], w_[3]);
		std::swap(A[1], A[3]);
		for (int i = 0; i != L; ++i)
		{
			uint64 m = -A[i] & 3;
			A[i] = norm((A[i] + m * Mod) >> 2);
		}
		return; 
	}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 1]);
		butterfly1(A[i + 2], A[i + 3]);
		butterfly1(A[i + 4], A[i + 5]);
		butterfly1(A[i + 6], A[i + 7]);
	}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 2]);
		butterfly(A[i + 1], A[i + 3], w[3], w_[3]);
		butterfly1(A[i + 4], A[i + 6]);
		butterfly(A[i + 5], A[i + 7], w[3], w_[3]);
	}
	for (int i = 0; i != L; i += 8)
	{
		butterfly1(A[i], A[i + 4]);
		butterfly(A[i + 1], A[i + 5], w[5], w_[5]);
		butterfly(A[i + 2], A[i + 6], w[6], w_[6]);
		butterfly(A[i + 3], A[i + 7], w[7], w_[7]);
	}
	for (int d = 8; d != L; d <<= 1)
		for (int i = 0; i != L; i += d << 1)
			for (int j = 0; j != d; j += 4)
			{
				butterfly(A[i + j], A[i + d + j], w[d + j], w_[d + j]);
				butterfly(A[i + j + 1], A[i + d + j + 1], w[d + j + 1], w_[d + j + 1]);
				butterfly(A[i + j + 2], A[i + d + j + 2], w[d + j + 2], w_[d + j + 2]);
				butterfly(A[i + j + 3], A[i + d + j + 3], w[d + j + 3], w_[d + j + 3]);
			}
#undef butterfly1
#undef butterfly
	std::reverse(A + 1, A + L);
	int k = __builtin_ctz(L);
	for (int i = 0; i != L; ++i)
	{
		uint64 m = -A[i] & (L - 1);
		A[i] = norm((A[i] + m * Mod) >> k);
	}
}

int N, M, L;
Z A[Max_size], B[Max_size];

void work(int f[], int g[], int n, int m, int res[]) {
	N = n; M = m;
	memset(A, 0, sizeof(A));
	memset(B, 0, sizeof(B));
	for(int i = 0; i < n; ++i) A[i].v = f[i];
	for(int i = 0; i < m; ++i) B[i].v = g[i];
	for (L = 2; L <= N + M - 2; L <<= 1)
		;
	init(L);
	
	DFT_fr_2(A, L), DFT_fr_2(B, L);
	for (int i = 0; i != L; ++i)
		A[i] *= B[i];
	IDFT_fr(A, L);
	
	for(int i = 0; i < n + m - 1; ++i) res[i] = A[i].v;
}

\(O(n^3)\) 做法片段:

sort(a + 1, a + n * 2 + 1);
int ans = 0;
for (int i = 1; i <= n; ++i) {
	for (int j = i; j <= 2 * n; ++j) {
		// a[j] -> p[i]
		for (int k = n - i + 1; k <= n * 2; ++k) {
			// a[k] -> q[i]
			if (a[j] == a[k]) continue;
			if (k < j) {
				if (k - 1 <= (i - 1) + n - i)
					add(ans, (ll)comb(k - 1, n - i) * comb(j - k - 1, (i - 1) - (k - 1 - (n - i))) % MOD * comb(n * 2 - j, n - i) % MOD * (a[j] - a[k]) % MOD);
			} else {
				if (j - 1 <= (i - 1) + n - i)
					add(ans, (ll)comb(j - 1, i - 1) * comb(k - j - 1, (n - i) - (j - 1 - (i - 1))) % MOD * comb(n * 2 - k, i - 1) % MOD * (a[k] - a[j]) % MOD);
			}
		}
	}
}
cout << ans << endl;

\(O(n^2)\) 做法片段:

sort(a + 1, a + n * 2 + 1);
int ans = 0;
for (int i = 1; i <= n; ++i) {
	for (int j = n + 1; j <= n * 2; ++j) {
		// k < j
		add(ans, (ll)comb(j - 1, i - 1) * comb(n * 2 - j, n - i) % MOD * a[j] % MOD);
	}
	for (int j = n; j >= 1; --j) {
		// j < k
		sub(ans, (ll)comb(n * 2 - j, n - i) * comb(j - 1, i - 1) % MOD * a[j] % MOD);
	}
	for (int k = n + 1; k <= n * 2; ++k) {
		// j < k
		add(ans, (ll)comb(k - 1, n - i) * comb(n * 2 - k, i - 1) % MOD * a[k] % MOD);
	}
	for (int k = n; k >= 1; --k) {
		// k < j
		sub(ans, (ll)comb(n * 2 - k, i - 1) * comb(k - 1, n - i) % MOD * a[k] % MOD);
	}
}