1. 程式人生 > 實用技巧 >GMOJ 6861 最終作戰 題解

GMOJ 6861 最終作戰 題解

題目大意

求長度為\(n\)且相鄰元素之差的絕對值大於一的數對個數不超過\(k\)的排列個數。

做法

DP莫得前途
那麼我們考慮生成函式。

考慮把原排列劃分成若干連續段(上升或下降)
則對於長度為1的連續段,指定值後只有一種排列方法。
對於長度大於1的連續段,指定值域後有兩種排列方式,一種遞增一種遞減
那麼我們寫出關於連續段長度的生成函式

\[\begin{aligned} A(x) & = x+\sum_{i \ge 2} 2x^i \\ & = \frac{2x}{1-x}-x \\ \end{aligned} \]

然後我們考慮把這些連續段拼起來,那麼選出來的第一段\(A_1\)

我們分配\([1, |A_1|]\),第二段\(A_2\)我們分配\([|A_1|+1, |A_1|+|A_2|]\)……以此類推。
注意我們欽定了他們對應哪些值,就勢必要考慮他們的位置,也就是要乘上連續段個數的階乘,即

\[\begin{aligned} G(x) & = k! [x^n](\frac{2x}{1-x}-x)^k \\ & = k! [x^{n-k}](\frac{2}{1-x} - 1)^k \qquad (提取x^k) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} [x^{n-k}] (\frac{1}{1-x})^i \qquad (二項式展開) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} [x^{n-k}] (\sum_{j\ge 0} x^j)^i \qquad (將後半段生成函式轉為一般形式) \\ & = k! \sum_{i=0}^{k} \tbinom{k}{i} 2^i (-1)^{k-i} \tbinom{n-k+i-1}{i-1} \qquad (考慮一般形式下後半段的意義,言下之意就是把n-k個相同物品放入i個不同盒子中) \\ & = (k!)^2 (n-k)! \sum_{i=0}^{k} \frac{2^i}{i!(i-1)!} \cdot \frac{(n-(k-i)-1)!(-1)^{k-i}}{(k-i)!} \qquad (展開組合數,簡單地變換) \\ \end{aligned} \]

那麼這個我們可以用卷積來做。
注意\(i=0\)時只有\(n=k\)才有一個\(1\)的貢獻。

發現這個樣子會算重。那麼我們設不算重的恰好有\(i+1\)個連續段的排列個數為\(f_i\),同時設會算重的恰好有\(i+1\)個連續段的排列個數為\(g_i\)
考慮每個\(f_i\)會被\(g_j\)算多少次,那麼就是在\(n-1-i\)個差為1的間隔中選出\(j-i\)個重複計算,即

\[\begin{aligned} g_i & = \sum_{0\le j \le i} \tbinom{n-1-j}{i-j} f_j \\ & = \sum_{0 \le j \le i} \frac{(n-1-j)!i!j!}{(i-j)!(n-1-i)!i!j!} f_j \\ \end{aligned} \]

發現原式並不符合二項式反演的樣子,所以我們乘上\(\frac{i!j!}{i!j!}\)再移項。

\[g_i \cdot i!(n-1-i)! = \sum_{0 \le j \le i} \tbinom{i}{j} f_j \cdot (n-1-j)!j! \]

那麼我們愉快地二項式反演+卷積即可。
答案就是\(\sum_{i=0}^{k} f_i\)

程式碼

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

namespace my {
    typedef long long ll;
    const int maxn=200000, size=524288, mods=998244353;

    ll fact[maxn+1], invf[maxn+1], pow2[maxn+1];

    ll qpower(ll a, int n) {
        ll s=1;
        for (; n; n/=2) {
            if (n&1) s=s*a%mods;
            a=a*a%mods;
        }
        return s;
    }

    ll invsize=qpower(size, mods-2);

    void initFact(int n) {
        fact[0] = 1;
        for (int i=1; i<=n; i++) fact[i]=fact[i-1]*i%mods;
        invf[n] = qpower(fact[n], mods-2);
        for (int i=n; i; i--) invf[i-1]=invf[i]*i%mods;
    }

    void initPow(int n) {
        pow2[0] = 1;
        for (int i=1; i<=n; i++) pow2[i] = pow2[i-1]*2%mods;
    }

    ll c(int m, int n) {return fact[m]*invf[n]%mods*invf[m-n]%mods;}

    ll getW(int n) {
        const int g=3;
        return qpower(g, (mods-1)/n);
    }

    void ntt(ll a[], int n, bool inv) {
        static int pos[size];
        int l=0;
        for (; (1<<l)<n; l++);
        l--;
        for (int i=1; i<n; i++) {
            pos[i] = (pos[i>>1]>>1)|((i&1)<<l);
            if (i<pos[i]) swap(a[i], a[pos[i]]);
        }

        for (int len=1; len<n; len*=2) {
            ll t=getW(len*2);
            if (inv) t=qpower(t, mods-2);
            for (int i=0; i<n; i+=len*2) {
                ll w=1, x, y;
                for (int j=0; j<len; j++) {
                    x=a[i+j], y=a[i+len+j]*w%mods;
                    a[i+j] = (x+y)%mods;
                    a[i+len+j] = (x+mods-y)%mods;
                    w=w*t%mods;
                }
            }
        }
    }

    int main() {
        freopen("fight.in", "r", stdin);
        freopen("fight.out", "w", stdout);

        int n, k;
        scanf("%d %d", &n, &k);

        initFact(maxn);
        initPow(maxn);
        static ll a[size], b[size], g[size], f[size];
        for (int i=0; i<=n; i++) {
            if (i) a[i] = pow2[i]*invf[i]%mods*invf[i-1]%mods;
            if (i<n) b[i] = (mods+fact[n-i-1]*invf[i]%mods*(i%2 ? -1 : 1))%mods;
        }
        ntt(a, size, false);
        ntt(b, size, false);
        for (int i=0; i<size; i++) g[i] = a[i]*b[i]%mods;
        ntt(g, size, true);
        for (int i=0; i<size; i++) g[i] = g[i]*invsize%mods;
        for (int i=0; i<=n; i++) g[i]=g[i]*fact[i]%mods*fact[i]%mods*fact[n-i]%mods;
        g[n] = (g[n]+(n%2 ? -1 : 1)+mods)%mods;

        for (int i=0; i<n; i++) g[i]=g[i+1]*fact[i]%mods*invf[n-1-i]%mods*invf[i]%mods;
        for (int i=n; i<size; i++) g[i]=0;
        memset(a, 0, sizeof a);
        for (int i=0; i<n; i++) a[i]=(mods+(i%2 ? -1 : 1)*invf[i])%mods;

        ntt(g, size, false);
        ntt(a, size, false);
        for (int i=0; i<size; i++) f[i]=g[i]*a[i]%mods;
        ntt(f, size, true);
        for (int i=0; i<size; i++) f[i]=f[i]*invsize%mods;

        for (int i=0; i<n; i++) f[i]=f[i]*invf[n-1-i]%mods;
        ll ans=0;
        for (int i=0; i<=k && i<n; i++) ans = (ans+f[i])%mods;
        printf("%lld\n", ans);

        fclose(stdin);
        fclose(stdout);
        return 0;
    }
};

int main() {return my::main();}