1. 程式人生 > 其它 >AGC056E Cheese【概率期望,dp】

AGC056E Cheese【概率期望,dp】

給定長為 \(n\) 的圓周,定義沿順時針方向距離為 \(x\) 的位置的座標為 \(x\)。初始時座標 \(i+0.5\) 的位置上有一隻老鼠。

進行 \(n-1\) 次操作,每次操作以 \(a_i\%\) 的概率選擇 \(i\pod{0\le i<n}\),在座標 \(i\) 放上一塊乳酪,之後乳酪順時針移動,每次遇到沒吃過乳酪的老鼠,都有 \(1/2\) 的概率被吃掉。

\(\forall k\in[0,n)\),求最後沒吃乳酪的老鼠是座標位於 \(k+0.5\) 的老鼠的概率 \(\bmod 998244353\)

\(n\le 40\)\(a_i\ge 0\)\(\sum a_i=100\)


只考慮 \(k=n-1\) 的情況,分別做即可。

本題最關鍵的結論是,在確定了每個座標 \(i\) 被放乳酪的次數 \(c_i\) 之後,概率僅與乳酪經過 \(-0.1\) 的次數 \(x\) 有關,而與放乳酪的順序無關。

\(b_i\) 表示乳酪經過 \(i+0.1\) 的次數,則 \(b_i=x-i+\sum_{j=0}^ic_j\),而整個過程實際上是一個乳酪與老鼠的匹配,所以與放乳酪的順序無關,乘上個可重排列方案數即可,而老鼠決策的概率為 \(2^{-x}\prod_{i=0}^{n-2}(1-2^{-b_i})\),表示第 \(0,1,\cdots,n-2\) 個老鼠至少要吃掉一個乳酪(雖然實際上只有第一次是吃了的),而第 \(n-1\)

個老鼠不能吃乳酪。

而如果任意取 \(c_i\)\(x\)\(b_i\) 有可能是負的,但因為 \(b_{i+1}-b_i\ge -1\) 所以此時必有一個 \(b_i=0\),算出的概率即為 \(0\),因此不需考慮這種情況。

最後答案即為 \((n-1)!\sum2^{-x}\prod_{i=0}^{n-2}(1-2^{i-\sum_{j\le i}c_j}\cdot 2^{-x})\prod_{i=0}^{n-1}\frac{a_i^{c_i}}{c_i!}\),要對所有 \(x\ge 0\) 求和所以應該表示為 \(2^{-x}\) 的多項式,對 \(p\) 次項係數乘上 \(1/(1-2^{-p})\)

然後加起來即可。直接 dp 維護,單次計算時間複雜度 \(O(n^4)\)

然而你會發現你輸出的恰好是答案的 \(2\) 倍,因為對於任意一個合法方案,把它加上一個獨立的圈之後的概率也被算了進去,而加一個圈之後概率變為原來的 \(1/2\)(不能被第 \(n-1\) 個老鼠吃掉),所以算出來的是答案的 \(1+1/2+1/4+\cdots=2\) 倍,最後除以 \(2\) 即可。

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 42, mod = 998244353, iv100 = 828542813;
int ksm(int a, int b){
    if(b < 0) b += mod - 1;
    int res = 1;
    for(;b;b >>= 1, a = (LL)a * a % mod)
        if(b & 1) res = (LL)res * a % mod;
    return res;
}
void qmo(int &x){x += x >> 31 & mod;}
int n, pr[N], a[N], fac[N], inv[N], f[N][N], g[N][N], val[N];
void solve(){
    memset(f, 0, sizeof f); **f = *val = 1;
    for(int i = 0;i < n;++ i){
        for(int j = 1;j < n;++ j) val[j] = (LL)val[j-1] * inv[j] % mod * a[i] % mod;
        memset(g, 0, sizeof g);
        for(int j = 0;j < n;++ j)
            for(int k = 0;j + k < n;++ k)
                for(int l = 0;l <= i;++ l)
                    g[j+k][l] = (g[j+k][l] + (LL)f[j][l] * val[k]) % mod;
        memset(f, 0, sizeof f);
        for(int j = 0;j < n;++ j){
            int coe = ksm(2, i - j);
            if(i == n-1) for(int k = 0;k < n;++ k) f[j][k+1] = (LL)coe * g[j][k] % mod;
            else for(int k = 0;k <= i;++ k){qmo(f[j][k] += g[j][k] - mod); f[j][k+1] = mod - (LL)coe * g[j][k] % mod;}
        }
    }
    int res = 0;
    for(int i = 1, pw = 1;i <= n;++ i){
        res = (res + (LL)pw * f[n-1][i] % mod * ksm(2*pw-1, mod-2)) % mod;
        qmo(pw += pw - mod);
    }
    printf("%lld ", (LL)res * fac[n-1] % mod);
}
int main(){
    ios::sync_with_stdio(false);
    cin >> n; *fac = fac[1] = inv[1] = 1;
    for(int i = 2;i <= n;++ i){
        fac[i] = (LL)fac[i-1] * i % mod;
        inv[i] = mod - (LL)mod / i * inv[mod % i] % mod;
    }
    for(int i = 0;i < n;++ i){cin >> pr[i]; pr[i] = (LL)pr[i] * iv100 % mod;}
    for(int i = 0;i < n;++ i){
        for(int j = 0;j <= i;++ j) a[n+j-i-1] = pr[j];
        for(int j = i+1;j < n;++ j) a[j-i-1] = pr[j];
        solve();
    }
}