1. 程式人生 > 其它 >LOJ 6267. 生成隨機數 題解

LOJ 6267. 生成隨機數 題解

好題吼!

題意

https://loj.ac/p/6267

題解

\(m=\sum a_i\)

首先肯定會想到建一棵二叉樹,葉節點數量 \(\ge m\) ,葉節點按 \(a_i\) 等比例分配。為了減少步數,把相同的狀態的葉節點放在一起,若一個子樹內葉節點相同,顯然到子樹根就可以停止,先不考慮返回葉節點的貢獻

\(\displaystyle\sum_{i=1}^n \sum_{j=0}^{\infty} [a_i\&2^j]2^{j-k}(k-j)\)

解釋:把每個 \(a_i\) 二進位制下的 \(1\) 放在一個子樹中,到了子樹根就返回,那麼該子樹根的深度為 \(k-j\) (根為 \(0\)

) ,走到子樹中的概率為 \(2^{j-k}\)

但是可能 \(\sum a_i\) 不是 \(2\) 的次冪,必須建返回節點!此時不知道選哪個 \(k\) 最優!

那麼我們考慮能不能用無限層 的二叉樹,把 \(\frac{a_i}{m}\) 都表示出來,這樣我們就不用建返回節點了

\(\frac{a_i}{m}\) 進行二進位制表示,第 \(2^{-i}\) 位的 \(1\) 對應到樹上的貢獻為 \(2^{-i} \cdot i\) ,而 \(\frac{a_i}{m}\) 一定能表示為二進位制下的 迴圈小數 !

\(\text{why?}\)

考慮大除法,只不過每次 \(\times 10\)

變為 \(\times 2\) ,此時對於確定的被除數,下一步的餘數也是定值,因為 \(m\) 有限,所以一定會形成環結構。

\(b_{i,j}\) 表示把 \(\frac{a_i}{m}\) 寫成二進位制小數後, \(2^{-j}\) 位是否為 \(1\) ,貢獻就是 \(\displaystyle E(\frac{a_i}{m})=\sum_{j=0}^{\infty} b_{i,j}\times j \times (\frac{1}{2})^j\)

然後可以證明, \(E(X)\) 是收斂的!**具體證明請看王總部落格 **(說明迴圈小數僅是為了說明貢獻收斂) 。

答案就是 \(\sum_{i=1}^n E(\frac{a_i}{m})\)

考慮找到 \(E(X)\)\(E(\{2X\})\) 之間的關係,其中 \(\{\}\) 表取小數部分
\({}\)

\[E(\{2X\})= \begin{cases} \displaystyle\sum_{j=0}^{\infty} b_{i,j+1}\times j \times (\frac{1}{2})^j=\sum_{j=0}^{\infty} b_{i,j}\times (j-1) \times (\frac{1}{2})^{j-1}=2E(X)-2X\ (X < \frac{1}{2}) \\ \displaystyle\sum_{j=0}^{\infty} b_{i,j+1}\times j \times (\frac{1}{2})^j-1=2E(X)-2\{X\}-1=2E(X)-2X\ (X \ge \frac{1}{2}) \\ \end{cases} \]

所以 \(E(\{2X\})=2E(X)-2X\)\(E(\frac{i}{m})\)\(E(\frac{2i}{m})\) 連邊,形成基環樹森林,然後在環的地方解方程,求出所有的 \(E(\frac{i}{m})\) 即可。

Code

#include<bits/stdc++.h>
#define ri register int
#define ll long long
using namespace std;
const int maxn = 1e6 + 10,mod = 998244353;
template<class T>
inline void rd(T &x){
    x = 0; char ch = getchar();
    while(!isdigit(ch)) ch = getchar();
    while(isdigit(ch)) x = x * 10 + ch - 48,ch = getchar();
}
int vis[maxn*10];
int n,a[maxn],m,rt;
ll f[maxn*10],g[maxn*10],ans[maxn*10],invm;
inline ll qp(ll x,ll k,ll res = 1){
	x %= mod;
	for(;k;k >>= 1,x = x * x % mod) if(k & 1) res = res * x % mod;
	return res;
}
void dfs(int u){
	vis[u] = 1;
	int nxt = (u<<1) % m; ll nf = f[u] * 2 % mod,ng = g[u] * 2 - u * 2;
	ng = (ng % mod + mod) % mod;
	if(vis[nxt] == 1){
		ans[rt] = (g[nxt] - ng + mod) * invm % mod * qp(nf - f[nxt] + mod,mod - 2) % mod;
		return;
	}
	if(vis[nxt] == 2){
		ans[rt] = (ans[nxt] - ng * invm % mod + mod) * qp(nf,mod - 2) % mod;
		return;
	}
	g[nxt] = ng,f[nxt] = nf,dfs(nxt);
}
void calc(int u){
	vis[u] = 2;
	int nxt = (u<<1) % m;
	if(vis[nxt] == 2) return;
	ans[nxt] = ans[u] * 2 - 2 * u * qp(m,mod - 2) % mod;
	ans[nxt] = (ans[nxt] % mod + mod) % mod;
	calc(nxt);
}
int main(){
	rd(n); for(ri i = 1;i <= n;++i) rd(a[i]),m += a[i];
	invm = qp(m,mod - 2);
	for(ri i = 1;i <= n;++i)
		if(!vis[a[i]]) rt = a[i],f[rt] = 1,g[rt] = 0,dfs(rt),calc(rt);
	ll res = 0;
	for(ri i = 1;i <= n;++i){
		res += ans[a[i]];
		if(res >= mod) res -= mod;
	}
	printf("%lld\n",res);
	return 0;
}