GMOJ 6861 最終作戰 題解
阿新 • • 發佈:2020-11-16
題目大意
求長度為\(n\)且相鄰元素之差的絕對值大於一的數對個數不超過\(k\)的排列個數。
做法
DP莫得前途
那麼我們考慮生成函式。
考慮把原排列劃分成若干連續段(上升或下降)
則對於長度為1的連續段,指定值後只有一種排列方法。
對於長度大於1的連續段,指定值域後有兩種排列方式,一種遞增一種遞減
那麼我們寫出關於連續段長度的生成函式
然後我們考慮把這些連續段拼起來,那麼選出來的第一段\(A_1\)
注意我們欽定了他們對應哪些值,就勢必要考慮他們的位置,也就是要乘上連續段個數的階乘,即 \[\begin{aligned} G(x) & = k! [x^n](\frac{2x}{1-x}-x)^k \\ & = k! [x^{n-k}](\frac{2}{1-x} - 1)^k \qquad (提取x^k) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} [x^{n-k}] (\frac{1}{1-x})^i \qquad (二項式展開) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} [x^{n-k}] (\sum_{j\ge 0} x^j)^i \qquad (將後半段生成函式轉為一般形式) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} \tbinom{n-k+i-1}{i-1} \qquad (考慮一般形式下後半段的意義,言下之意就是把n-k個相同物品放入i個不同盒子中) \\ & = (k!)^2 (n-k)! \sum_{i=0}^{k} \frac{2^i}{i!(i-1)!} \cdot \frac{(n-(k-i)-1)!(-1)^{k-i}}{(k-i)!} \qquad (展開組合數,簡單地變換) \\ \end{aligned} \]
那麼這個我們可以用卷積來做。
注意\(i=0\)時只有\(n=k\)才有一個\(1\)的貢獻。
發現這個樣子會算重。那麼我們設不算重的恰好有\(i+1\)個連續段的排列個數為\(f_i\),同時設會算重的恰好有\(i+1\)個連續段的排列個數為\(g_i\)。
考慮每個\(f_i\)會被\(g_j\)算多少次,那麼就是在\(n-1-i\)個差為1的間隔中選出\(j-i\)個重複計算,即
發現原式並不符合二項式反演的樣子,所以我們乘上\(\frac{i!j!}{i!j!}\)再移項。
\[g_i \cdot i!(n-1-i)! = \sum_{0 \le j \le i} \tbinom{i}{j} f_j \cdot (n-1-j)!j! \]那麼我們愉快地二項式反演+卷積即可。
答案就是\(\sum_{i=0}^{k} f_i\)
程式碼
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
namespace my {
typedef long long ll;
const int maxn=200000, size=524288, mods=998244353;
ll fact[maxn+1], invf[maxn+1], pow2[maxn+1];
ll qpower(ll a, int n) {
ll s=1;
for (; n; n/=2) {
if (n&1) s=s*a%mods;
a=a*a%mods;
}
return s;
}
ll invsize=qpower(size, mods-2);
void initFact(int n) {
fact[0] = 1;
for (int i=1; i<=n; i++) fact[i]=fact[i-1]*i%mods;
invf[n] = qpower(fact[n], mods-2);
for (int i=n; i; i--) invf[i-1]=invf[i]*i%mods;
}
void initPow(int n) {
pow2[0] = 1;
for (int i=1; i<=n; i++) pow2[i] = pow2[i-1]*2%mods;
}
ll c(int m, int n) {return fact[m]*invf[n]%mods*invf[m-n]%mods;}
ll getW(int n) {
const int g=3;
return qpower(g, (mods-1)/n);
}
void ntt(ll a[], int n, bool inv) {
static int pos[size];
int l=0;
for (; (1<<l)<n; l++);
l--;
for (int i=1; i<n; i++) {
pos[i] = (pos[i>>1]>>1)|((i&1)<<l);
if (i<pos[i]) swap(a[i], a[pos[i]]);
}
for (int len=1; len<n; len*=2) {
ll t=getW(len*2);
if (inv) t=qpower(t, mods-2);
for (int i=0; i<n; i+=len*2) {
ll w=1, x, y;
for (int j=0; j<len; j++) {
x=a[i+j], y=a[i+len+j]*w%mods;
a[i+j] = (x+y)%mods;
a[i+len+j] = (x+mods-y)%mods;
w=w*t%mods;
}
}
}
}
int main() {
freopen("fight.in", "r", stdin);
freopen("fight.out", "w", stdout);
int n, k;
scanf("%d %d", &n, &k);
initFact(maxn);
initPow(maxn);
static ll a[size], b[size], g[size], f[size];
for (int i=0; i<=n; i++) {
if (i) a[i] = pow2[i]*invf[i]%mods*invf[i-1]%mods;
if (i<n) b[i] = (mods+fact[n-i-1]*invf[i]%mods*(i%2 ? -1 : 1))%mods;
}
ntt(a, size, false);
ntt(b, size, false);
for (int i=0; i<size; i++) g[i] = a[i]*b[i]%mods;
ntt(g, size, true);
for (int i=0; i<size; i++) g[i] = g[i]*invsize%mods;
for (int i=0; i<=n; i++) g[i]=g[i]*fact[i]%mods*fact[i]%mods*fact[n-i]%mods;
g[n] = (g[n]+(n%2 ? -1 : 1)+mods)%mods;
for (int i=0; i<n; i++) g[i]=g[i+1]*fact[i]%mods*invf[n-1-i]%mods*invf[i]%mods;
for (int i=n; i<size; i++) g[i]=0;
memset(a, 0, sizeof a);
for (int i=0; i<n; i++) a[i]=(mods+(i%2 ? -1 : 1)*invf[i])%mods;
ntt(g, size, false);
ntt(a, size, false);
for (int i=0; i<size; i++) f[i]=g[i]*a[i]%mods;
ntt(f, size, true);
for (int i=0; i<size; i++) f[i]=f[i]*invsize%mods;
for (int i=0; i<n; i++) f[i]=f[i]*invf[n-1-i]%mods;
ll ans=0;
for (int i=0; i<=k && i<n; i++) ans = (ans+f[i])%mods;
printf("%lld\n", ans);
fclose(stdin);
fclose(stdout);
return 0;
}
};
int main() {return my::main();}