LOJ 6267. 生成隨機數 題解
好題吼!
題意
題解
令 \(m=\sum a_i\) 。
首先肯定會想到建一棵二叉樹,葉節點數量 \(\ge m\) ,葉節點按 \(a_i\) 等比例分配。為了減少步數,把相同的狀態的葉節點放在一起,若一個子樹內葉節點相同,顯然到子樹根就可以停止,先不考慮返回葉節點的貢獻:
\(\displaystyle\sum_{i=1}^n \sum_{j=0}^{\infty} [a_i\&2^j]2^{j-k}(k-j)\)
解釋:把每個 \(a_i\) 二進位制下的 \(1\) 放在一個子樹中,到了子樹根就返回,那麼該子樹根的深度為 \(k-j\) (根為 \(0\)
但是可能 \(\sum a_i\) 不是 \(2\) 的次冪,必須建返回節點!此時不知道選哪個 \(k\) 最優!
那麼我們考慮能不能用無限層 的二叉樹,把 \(\frac{a_i}{m}\) 都表示出來,這樣我們就不用建返回節點了!
把 \(\frac{a_i}{m}\) 進行二進位制表示,第 \(2^{-i}\) 位的 \(1\) 對應到樹上的貢獻為 \(2^{-i} \cdot i\) ,而 \(\frac{a_i}{m}\) 一定能表示為二進位制下的 迴圈小數 !
\(\text{why?}\)
考慮大除法,只不過每次 \(\times 10\)
\(b_{i,j}\) 表示把 \(\frac{a_i}{m}\) 寫成二進位制小數後, \(2^{-j}\) 位是否為 \(1\) ,貢獻就是 \(\displaystyle E(\frac{a_i}{m})=\sum_{j=0}^{\infty} b_{i,j}\times j \times (\frac{1}{2})^j\) 。
然後可以證明, \(E(X)\) 是收斂的!**具體證明請看王總部落格 **(說明迴圈小數僅是為了說明貢獻收斂) 。
答案就是 \(\sum_{i=1}^n E(\frac{a_i}{m})\)
考慮找到 \(E(X)\) 與 \(E(\{2X\})\) 之間的關係,其中 \(\{\}\) 表取小數部分
\({}\)
所以 \(E(\{2X\})=2E(X)-2X\) 。 \(E(\frac{i}{m})\) 向 \(E(\frac{2i}{m})\) 連邊,形成基環樹森林,然後在環的地方解方程,求出所有的 \(E(\frac{i}{m})\) 即可。
Code
#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = 1e6 + 10,mod = 998244353;
template<class T>
inline void rd(T &x){
x = 0; char ch = getchar();
while(!isdigit(ch)) ch = getchar();
while(isdigit(ch)) x = x * 10 + ch - 48,ch = getchar();
}
int vis[maxn*10];
int n,a[maxn],m,rt;
ll f[maxn*10],g[maxn*10],ans[maxn*10],invm;
inline ll qp(ll x,ll k,ll res = 1){
x %= mod;
for(;k;k >>= 1,x = x * x % mod) if(k & 1) res = res * x % mod;
return res;
}
void dfs(int u){
vis[u] = 1;
int nxt = (u<<1) % m; ll nf = f[u] * 2 % mod,ng = g[u] * 2 - u * 2;
ng = (ng % mod + mod) % mod;
if(vis[nxt] == 1){
ans[rt] = (g[nxt] - ng + mod) * invm % mod * qp(nf - f[nxt] + mod,mod - 2) % mod;
return;
}
if(vis[nxt] == 2){
ans[rt] = (ans[nxt] - ng * invm % mod + mod) * qp(nf,mod - 2) % mod;
return;
}
g[nxt] = ng,f[nxt] = nf,dfs(nxt);
}
void calc(int u){
vis[u] = 2;
int nxt = (u<<1) % m;
if(vis[nxt] == 2) return;
ans[nxt] = ans[u] * 2 - 2 * u * qp(m,mod - 2) % mod;
ans[nxt] = (ans[nxt] % mod + mod) % mod;
calc(nxt);
}
int main(){
rd(n); for(ri i = 1;i <= n;++i) rd(a[i]),m += a[i];
invm = qp(m,mod - 2);
for(ri i = 1;i <= n;++i)
if(!vis[a[i]]) rt = a[i],f[rt] = 1,g[rt] = 0,dfs(rt),calc(rt);
ll res = 0;
for(ri i = 1;i <= n;++i){
res += ans[a[i]];
if(res >= mod) res -= mod;
}
printf("%lld\n",res);
return 0;
}