【agc019F】Yes or No
Description
給你\(n+m\)個詢問,其中\(n\)個的答案是\(Yes\),\(m\)個的答案是\(No\),現在依次回答這些詢問,每回答一個詢問就告訴你聽你回答對了還是沒對,求最優策略下答對題目期望數量對\(998244353\)取模
Solution
個人感覺很棒的一題qwq
首先我們可以設出一個無腦dp:\(f[n][m]\)表示有\(n\)個\(Yes\)和\(m\)個\(No\)情況下的答案,策略的話當然是哪個剩的多就猜哪個,一樣多隨便猜一個,那麼我們可以得到轉移:
\[ f[n][m]=\begin{cases} \frac{n}{n+m}(f[n-1][m]+1)+\frac{m}{n+m}f[n][m-1]&(n>m)\\ \\ \frac{m}{n+m}(f[n][m-1]+1)+\frac{n}{n+m}f[n-1][m]&(n<m)\\ \\ 上面隨便選一個(其實就是隨便亂猜有\frac{1}{2}的概率有1的貢獻)&(n=m) \end{cases} \]
然後我們可以將這個東西放到一個。。座標系裡面,橫座標對應\(n\),縱座標對應\(m\),那麼一種回答的方案就相當於從\((n,m)\)出發到\((0,0)\)的一條路徑,考慮畫一條\(y=x\)的直線,整個座標系被這條線分成了兩大部分,線上方的點都滿足\(y>x\),下方則是\(x>y\),那麼放回上面的式子裡面,先不看概率只看貢獻,會發現在下方橫向的路徑是有貢獻的,上方縱向的路徑是有貢獻的
為了方便下面的描述,我們不妨令\(n>=m\),因為我們的策略中如果一樣多就隨便猜一個,所以從對角線上點轉移出來的答案應該還要乘上\(\frac{1}{2}\)(隨便猜有\(\frac{1}{2}\)
先考慮比較簡單的\(n=m\)的情況:考慮一條從\((n,n)\)到\((0,0)\)的不碰到對角線的路徑,這樣的一種方案中所有的邊的貢獻都是確定的可以直接計算,會發現不管怎麼走,每條路徑一定會有\(n\)的貢獻
那麼再看\(n>m\)的情況:考慮一條從\((n,m)\)到\((0,0)\)的路徑(可以經過對角線),我們按照觸碰對角線的節點將這條路徑劃分成若干個部分,除了第一部分(也就是從\((n,m)\)走到碰到的第一個對角線上的點的這段)以外,其他部分都可以看成是從對角線上某一個點出發,中途不經過對角線,在對角線上某個點結束的一段路程,其實也就是我們的\(n=m\)
所以,我們可以得到一個結論:確定的貢獻為\(max(n,m)\),接下來真正受概率影響的就只有那些對角線上的點的貢獻了
而這些點的貢獻其實也很好計算,只要有一條路徑經過對角線上的一個點,那麼不管是橫著走還是豎著走的,都有\(\frac{1}{2}\)的概率獲得\(1\)的貢獻,所以我們只要對於對角線上面的每一個點計算經過它的方案數,然後除以總的路徑數量,再乘上\(\frac{1}{2}\)即可
mark:沒事把這種-1轉移的二維dp丟到座標系裡面轉成路徑什麼的好像挺有用的
程式碼大概長這個樣子
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=5*(1e5)+10,MOD=998244353,inv2=499122177;
int fac[N*2],invfac[N*2];
int n,m;
int mul(int x,int y){return 1LL*x*y%MOD;}
int plu(int x,int y){return (1LL*x+y)%MOD;}
int C(int n,int m){return n<m?0:mul(fac[n],mul(invfac[m],invfac[n-m]));}
int calc(int n,int m){return C(n+m,m);}
int ksm(int x,int y){
int ret=1,base=x;
for (;y;y>>=1,base=mul(base,base))
if (y&1) ret=mul(ret,base);
return ret;
}
void prework(int n){
fac[0]=1;
for (int i=1;i<=n;++i) fac[i]=mul(fac[i-1],i);
invfac[n]=ksm(fac[n],MOD-2);
for (int i=n-1;i>=0;--i) invfac[i]=mul(invfac[i+1],i+1);
}
void solve(){
int ans=0;
for (int i=1;i<=min(n,m);++i)
ans=plu(ans,mul(calc(i,i),calc(n-i,m-i)));
ans=mul(ans,ksm(calc(n,m),MOD-2));
ans=mul(ans,inv2);
printf("%d\n",ans+max(n,m));
}
int main(){
#ifndef ONLINE_JUDGE
freopen("a.in","r",stdin);
#endif
scanf("%d%d",&n,&m);
prework(n+m);
solve();
}