CF908D New Year and Arbitrary Arrangement 題解
阿新 • • 發佈:2022-06-01
\(0.\) 前言
有一天 \(Au\) 爺講期望都見到了此題,通過寫題解來加深理解。
\(1.\) 題意
將初始為空的序列的末尾給定概率新增 \(a\) 或 \(b\),當至少有 \(k\) 對 \(ab\) 時停止(注意是“對”,中間可以間隔字元),求 \(ab\) 期望對數。
\(2.\) 思路
通過檢視標籤 通過閱讀題面我們容易發現本題是一道期望 DP,但是本題的狀態並不很容易想到,設 \(f[i][j]\) 表示字首中有 \(i\) 個 \(a\),\(j\) 個 \(ab\) 停止後的期望個數,這樣發現轉移就容易了很多,不會被 \(a\) 和 \(b\) 糾纏不清,設 \(A = pa / (pa + pb)\)
若 \(i + j ⩾ k\),則再加一個 \(b\) 就會結束,此時的期望 \(ab\) 數是:
\[i + j + pa / pb \]故終止狀態為:
\[f[i][j] = i + j + pa / pb, i + j ⩾ k \]\(3.\) 解釋
(本塊主要針對 \(i + j + pa / pb\) 的推導,不感興趣可以跳過)
我一直疑惑 \(i + j + pa / pb\) 如何得出。
解釋一下,在字首有了 \(i\)
接下來的證明部分參考一粒夸克的部落格
首先是等差乘等比數列求和公式
\[(1):A=a+(a+p)×p+(a+2×b)×p^2+...+(a+n×b)×p^n \] \[(2):A×p=a×p+(a+b)×p^2+(a+2×b)×p^3+...+(a+n×b)×p^{n+1} \] \[(1)-(2):A×(1-p)=a+b×(p+p^2+p^3+...+p^n)-(a+n×b)p^{n+1} \] \[A×(1-p)=a+b×p×{1-p^n \over 1-p}-(a+n×b)×p^{n+1} \] \[A={a\over1-p}+b×{p-p^{n+1}\over(1-p)^2}-{(a+n×b)×p^{n+1}\over1-p} \]將公式代入無限和式
(這麼巨量\(\LaTeX\)我都打了,求贊)
\(4.\) 細節
- 由於 \(f[0][0]\) 會轉移到自己,遞迴記憶化會死迴圈,從 \(f[1][0]\) 開始算,當序列前有一堆 \(b\) 的情況沒有意義,可以跳到第一個 \(a\) 發生時開始算。初始狀態選取 \(f[1][0]\)。
- 當 \(a\) 與 \(ab\) 的個數相加已經大於 \(k\) 了,這是就不關心有多少 \(a\) 了,只需要有一個 \(b\) 就可以結束了,這樣可以把兩維都控制在 \(O(k)\) 的複雜度
\(5.\) 程式碼
這是一份逆推實現的程式碼:
#include<map>
#include<cmath>
#include<queue>
#include<vector>
#include<cstdio>
#include<string>
#include<cstring>
#include<iostream>
#include<algorithm>
#define int long long
using namespace std;
template<class T> inline void read(T &x){
x = 0; register char c = getchar(); register bool f = 0;
while(!isdigit(c)) f ^= c == '-', c = getchar();
while(isdigit(c)) x = x * 10 + c - '0', c = getchar();
if(f) x = -x;
}
template<class T> inline void print(T x){
if(x < 0) putchar('-'), x = -x;
if(x > 9) print(x / 10);
putchar('0' + x % 10);
}
const int N = 1010;
const int mod = 1e9 + 7;
int n, pa, pb, A, B, C;
int f[N][N];
inline int qpow(int a, int b){
int res = 1;
while(b){
if(b & 1) res = 1ll * res * a % mod;
a = 1ll * a * a % mod, b >>= 1;
}
return res;
}
inline int work(int x){
return qpow(x, mod - 2);
}
signed main(){
read(n), read(pa), read(pb);
A = 1ll * pa * work(pa + pb) % mod;
B = 1ll * pb * work(pa + pb) % mod;
C = 1ll * pa * work(pb) % mod;
for(int i = n; i >= 1; --i)
for(int j = n; j >= 0; --j){
if(i + j >= n) f[i][j] = (i + j + C) % mod;
else f[i][j] = (1ll * A * f[i + 1][j] % mod + 1ll * B * f[i][j + i] % mod) % mod;
}
print(f[1][0]), puts("");
return 0;
}
這是一份記搜實現的程式碼片段:
inline int dp(int i, int j){
if(i + j >= k) return (i + j + C) % mod;
if(~ f[i][j]) return f[i][j];
return (1ll * A * dp(i + 1, j) + 1ll * B * dp(i, j + i)) % mod;
}