@codeforces - [email protected] Bandit Blues
目錄
@[email protected]
求有多少個長度為 n 的排列,從左往右遍歷有 a 個數比之前遍歷的所有數都大,從右往左遍歷有 b 個數比之前遍歷的所有數都大。
模 998244323。
input
一行三個整數 n, a, b。1 ≤ n ≤ 10^5,0 ≤ A, B ≤ n。
output
輸出排列數模 998244353。
sample input
5 2 2
sample output
22
@[email protected]
@part - [email protected]
首先從左往右和從右往左都會在最大值的地方停下來。
我們列舉最大值的位置,並記 dp(i, j) 表示 i 個元素順序遍歷有 j 個符合要求的元素的方案數。
則:
\[ans=\sum_{i=1}^{n}dp(i-1, a-1)*dp(n-i, b-1)*C(n-1,i-1)\]
為什麼要減去 1 呢?因為我們的最後一個元素一定是最大值。
考慮怎麼求解 dp(i, j)。我們為了避免繁雜的列舉,直接考慮 i 個元素中最小的那個元素的位置。如果最小的元素是第一個,則它一定被計算進去,剩下的狀態變為 dp(i-1, j-1);否則,它一定不會被計算進去,就可以刪除它,變為 dp(i-1, j)。
故:
\[dp(i, j) = dp(i-1, j-1) + (i-1)*dp(i-1, j-1)\]
如果你對組合數學足夠熟悉,就會發現上面那個式子其實是第一類斯特林數 \(s(i, j)\) 的遞推式。
考慮其組合意義。如果我們最後符合要求的數為 \(a_{p_1}, a_{p_2}, \dots , a_{p_j}\),則一定有 \(a_{p_1+1\dots p_2-1} < a_{p_1}\)
如果我們把 \(a_{p_1...p_2-1}\) 看成一個整體,則這個整體對答案的貢獻其實是圓排列——每個排列都必須要保證 \(a_{p_1}\) 在第一個位置,就像是某個圓排列將 \(a_{p_1}\) 旋轉到第一個位置。
如果我們確定了數放在哪一個圓排列中,則圓排列之間的相對位置是唯一的,因為我們必須要滿足 \(a_{p_1} < a_{p_2} < \dots < a_{p_j}\)。也就是說最後的方案數就是將 i 個數分成 j 個圓排列的方案數——即第一類斯特林數的定義。
既然扯到了組合意義,那麼最初那個列舉最大值的位置可不可以直接用組合數學的方法來搞定了?
我們可以這樣理解:先將 n-1(除最大值以外)個數分成 a+b-2 個圓排列,再將這 a+b-2 個圓排列黑白染色,選擇 a-1 個染黑色(放在最大值左邊),剩下的染白色(放在最大值右邊)。則:
\[ans = s(n-1, a+b-2)*C(a+b-2,a-1)\]
@part - [email protected]
【接下來只是來講講怎麼 O(nlogn) 求解第一類斯特林數的,如果你已經很熟悉了可以直接跳過這一段】
我們根據這樣一個公式進行求解:
\[x(x+1)(x+2)\dots(x+n-1)=\sum_{i=0}^{n}s(n,i)x^i\]
有些類似於二項式定理。可以根據對最後一項是選擇 x 還是 n-1 得到和我們遞推公式一樣的結果。
我們利用倍增解決這一問題。
記 \(f_n(x)=\prod_{i=0}^{n-1}(x+i) = a_0+a_1x+\dots+a_{n-1}x^{n-1}\)。
則 \(f_{2n}(x) = f(x)*f(x+n)\),\(f_{2n+1}(x)=f(x)*f(x+n)*(x+2n)\)。
如果已知 \(f(x+n)\),則可以用 fft 快速計算多項式乘法。
考慮怎麼已知 \(f_n(x)\) 求 \(f_n(x+n)\)。將 \(f_n(x+n)\) 的式子寫出來:
\[f_n(x+n)=\sum_{i=0}^{n-1}a_i(x+n)^i\]
二項式展開:
\[f_n(x+n)=\sum_{i=0}^{n-1}a_i(\sum_{j=0}^{i}C(i,j)*n^j*x^{i-j})\]
把內層的求和去掉:
\[f_n(x+n)=\sum_{0\le j\le i<n}a_i*C(i,j)*n^j*x^{i-j}\]
把組合數拆成階乘形式,並適當整理:
\[f_n(x+n)=\sum_{0\le j\le i<n}(a_i*i!)*(\frac{n^j}{j!})*(\frac{x^{i-j}}{(i-j)!})\]
如果記 \(A_i = a_i*i!\),\(B_i = \frac{n^j}{j!}\),則我們相當於是要求解 A 與 B 的減法卷積。將 A 翻轉一下就可以正常用 fft 做加法卷積,然後把結果再翻轉回來即可。
@accepted [email protected]
注意一些該特判的地方還是要特判。
#include<cstdio>
#include<algorithm>
using namespace std;
const int G = 3;
const int MOD = 998244353;
const int MAXN = 400000;
int pow_mod(int b, int p) {
int ret = 1;
while( p ) {
if( p & 1 ) ret = 1LL*ret*b%MOD;
b = 1LL*b*b%MOD;
p >>= 1;
}
return ret;
}
int fct[MAXN + 5], inv[MAXN + 5];
void ntt(int *A, int n, int type) {
for(int i=0,j=0;i<n;i++) {
if( i < j ) swap(A[i], A[j]);
for(int l=(n>>1);(j^=l)<l;l>>=1);
}
for(int s=2;s<=n;s<<=1) {
int t = (s>>1);
int u = (type == 1) ? pow_mod(G, (MOD-1)/s) : pow_mod(G, (MOD-1)-(MOD-1)/s);
for(int i=0;i<n;i+=s) {
int p = 1;
for(int j=0;j<t;j++,p=1LL*p*u%MOD) {
int x = A[i+j], y = 1LL*A[i+j+t]*p%MOD;
A[i+j] = (x + y)%MOD, A[i+j+t] = (x + MOD - y)%MOD;
}
}
}
if( type == -1 ) {
int k = 1LL*fct[n-1]*inv[n]%MOD;
for(int i=0;i<n;i++)
A[i] = 1LL*A[i]*k%MOD;
}
}
void init() {
fct[0] = 1;
for(int i=1;i<=MAXN;i++)
fct[i] = 1LL*fct[i-1]*i%MOD;
inv[MAXN] = pow_mod(fct[MAXN], MOD - 2);
for(int i=MAXN-1;i>=0;i--)
inv[i] = 1LL*inv[i+1]*(i+1)%MOD;
}
int comb(int n, int m) {
return 1LL*fct[n]*inv[m]%MOD*inv[n-m]%MOD;
}
int tmp1[MAXN + 5], tmp2[MAXN + 5], tmp3[MAXN + 5];
void sterling1(int *A, int n) {
if( !n ) {
A[0] = 1;
return ;
}
int m = n/2, pw = 1, len;
sterling1(A, m);
for(len = 1;len <= n;len <<= 1);
for(int i=0;i<=m;i++) tmp1[m - i] = 1LL*fct[i]*A[i]%MOD;
for(int i=0;i<=m;i++) tmp2[i] = 1LL*inv[i]*pw%MOD, pw=1LL*pw*m%MOD;
for(int i=m+1;i<len;i++) tmp1[i] = tmp2[i] = 0;
ntt(tmp1, len, 1), ntt(tmp2, len, 1);
for(int i=0;i<len;i++) tmp1[i] = 1LL*tmp1[i]*tmp2[i]%MOD;
ntt(tmp1, len, -1);
for(int i=0;i<=m;i++) tmp3[m - i] = 1LL*tmp1[i]*inv[m - i]%MOD;
for(int i=0;i<=m;i++) tmp1[i] = A[i];
for(int i=m+1;i<len;i++) tmp1[i] = tmp3[i] = 0;
if( n & 1 ) {
tmp2[1] = 1, tmp2[0] = (MOD + n - 1);
for(int i=2;i<len;i++) tmp2[i] = 0;
}
else {
tmp2[0] = 1;
for(int i=1;i<len;i++) tmp2[i] = 0;
}
ntt(tmp1, len, 1), ntt(tmp2, len, 1), ntt(tmp3, len, 1);
for(int i=0;i<len;i++) tmp1[i] = 1LL*tmp1[i]*tmp2[i]%MOD*tmp3[i]%MOD;
ntt(tmp1, len, -1);
for(int i=0;i<=n;i++) A[i] = tmp1[i];
}
int f[MAXN + 5];
int main() {
int n, a, b; init();
scanf("%d%d%d", &n, &a, &b);
if( a + b > n + 1 || a == 0 || b == 0 ) {
printf("%d\n", 0);
return 0;
}
sterling1(f, n - 1);
/*
for(int i=0;i<=n-1;i++)
printf("%d ", f[i]);
puts("");
*/
printf("%lld\n", 1LL*f[a + b - 2]*comb(a + b - 2, a - 1)%MOD);
}
@[email protected]
寫程式的時候突然發現斯特林數的簡寫是 STL。
我就說用998244353這個模數肯定是ntt嘛。
不要忘記乘上 \(\frac{1}{(i-j)!}\)。
奇數長度的還要多乘一個多項式。
邊界當 n = 0 的時候,返回一個常數 1。