1. 程式人生 > 其它 >洛谷 「P6475 [NOI Online #2 入門組] 建設城市」

洛谷 「P6475 [NOI Online #2 入門組] 建設城市」

洛谷 「P6475 [NOI Online #2 入門組] 建設城市」

傳送門

\(\texttt{Description}\)

求滿足如下條件的序列 \(a\) 的數量:

  • 長度為 \(2n\)

  • \(\forall i\in[1,n],a_i\in[1,m]\)\(a_i\) 為正整數。

  • \(n\) 項單調不降,後 \(n\) 項單調不增。

  • 要求 \(a_x=a_y\)

答案對 \(998244353\) 取模。

\(\texttt{Data Range:}1\le x<y\le2n,1\le n,m\le10^5\)

\(\texttt{Solution}\)

分兩種情況討論。

  • \(x\)\(y\) 在異側,即 \(x\le n,y>n\)


    首先列舉 \(x\)\(y\) 的值 \(i\),然後分四段來看。
    第一段是 \(1\sim x-1\),第二段是 \(x+1\sim n\),第三段是 \(n+1\sim y-1\),第四段是 \(y+1\sim2n\)
    對於第一段,需要單調不降且範圍是 \(1\sim i\),那麼可以看成在這些數之間插 \(i-1\) 塊板,將其分為 \(i\) 段,第 \(1\) 段代表值為 \(1\) 的,以此類推。
    這樣就滿足了單調不降的限制,我們又知道插板法的公式,所以第一段的答案就求出來了,第二、三、四段同理。
    用乘法原理把四段答案相乘,再用加法原理把每一次枚舉出的答案相加即可。

  • \(x\)

    \(y\) 在同側,即 \(x\le n,y\le n\)\(x>n,y>n\)
    這裡可以繼續沿用上面的 trick。

於是沒了 qwq

\(\texttt{Code}\)

#include <cstdio>
#include <iostream>
using namespace std;
typedef long long ll;
const ll mod = 998244353;

ll ans, fac[300005];
inline void init(int n) {
	fac[0] = 1;
	for (int i = 1; i <= n; i++) fac[i] = fac[i - 1] * i % mod;
}
inline ll quick_pow(ll a, ll k, ll p) {
	ll res = 1;
	a %= p;
	while (k) {
		if (k & 1) res = res * a % p;
		a = a * a % p;
		k >>= 1;
	}
	return res;
}
inline ll inv(ll a, ll p) {return quick_pow(a, p - 2, p);}
inline ll get_C(int n, int m, ll p) {
	if (n < m) return 0;
	return fac[n] * inv(fac[m] * fac[n - m], p) % p;
}

int main() {
	init(3e5);
	int m, n, x, y;
	scanf("%d %d %d %d", &m, &n, &x, &y);
	if ((x <= n) ^ (y <= n)) {
		for (int i = 1; i <= m; i++) {
			ll tmp1 = get_C(x + i - 2, i - 1, mod);
			ll tmp2 = get_C(n - x + m - i, m - i, mod);
			ll tmp3 = get_C(y - n - 1 + m - i, m - i, mod);
			ll tmp4 = get_C(n * 2 - y + i - 1, i - 1, mod);
			ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod * tmp4 % mod) % mod;
		}
		printf("%lld", ans);
		return 0;
	}
	if (x <= n && y <= n) {
		for (int i = 1; i <= m; i++) {
			ll tmp1 = get_C(x + i - 2, i - 1, mod);
			ll tmp2 = get_C(n - y + m - i, m - i, mod);
			ll tmp3 = get_C(n + m - 1, m - 1, mod);
			ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod) % mod;
		}
		printf("%lld", ans);
		return 0;
	}
	for (int i = 1; i <= m; i++) {
		ll tmp1 = get_C(n + m - 1, m - 1, mod);
		ll tmp2 = get_C(x - n + m - i - 1, m - i, mod);
		ll tmp3 = get_C(n * 2 - y + i - 1, i - 1, mod);
		ans = (ans + tmp1 * tmp2 % mod * tmp3 % mod) % mod;
	}
	printf("%lld", ans);
	return 0;
}