1. 程式人生 > >[SHOI2014] 概率充電器

[SHOI2014] 概率充電器

Description

給定一棵\(N(N\leq 5\times 10^5)\)個節點的樹。每個點有概率被直接充電,每條邊有概率導電。如果一個沒有被直接充電的點通過一條導電的邊連線到了某個被充電的點,那麼這個點也會被充電。問期望充電的點的個數。

Solution

由期望線性性,我們可以求出每個點被充電的概率,最後求和即可。

節點\(i\)通電有三種可能

  1. 它自己來電
  2. 它子樹裡有一個來電傳了過來
  3. 子樹外有一個來電傳了過來

可以兩遍\(dfs\)做,第一遍先求出第一種和第二種的概率,再換根\(dfs\)加上第三種的概率就行了。這是最開始的\(naive\)想法。

但是這些情況發生的概率不能直接加起來。考慮兩個事件\(A,B\)

,發生的概率分別是\(P(A),P(B)\),那麼至少發生一件的概率是\(P(A)+P(B)-P(A)*P(B)\)

證明就是列舉三種情況:

  1. \(A\)發生\(B\)不發生,概率為\(P(A)*(1-P(B))\)
  2. \(A\)不發生\(B\)發生,概率為\((1-P(A))*P(B)\)
  3. \(A,B\)都發生,概率為\(P(A)*P(B)\)

三種情況求和就是\(P(A)+P(B)-P(A)*P(B)\)了。

知道了這個做兩遍\(dfs\)就行了剛做過一個差不多的

啊Typora程式碼框裡的字型好好看啊誰知道這是什麼字型啊求告知QAQ

Code

#include<set>
#include<map>
#include<cmath>
#include<queue>
#include<cctype>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
using std::min;
using std::max;
using std::swap;
using std::vector;
const int N=500005;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)

db f[N],p[N];
int n,cnt,head[N];

struct Edge{
    int to,nxt;db dis;
}edge[N<<1];

void add(int x,int y,db z){
    edge[++cnt].to=y;
    edge[cnt].nxt=head[x];
    edge[cnt].dis=z;
    head[x]=cnt;
}

int getint(){
    int X=0,w=0;char ch=0;
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

void dfs(int now,int fa){
    f[now]=p[now];
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(to==fa) continue;
        dfs(to,now);
        f[now]=f[now]+f[to]*edge[i].dis-f[now]*f[to]*edge[i].dis;
    }
}

void dfs(int now,int fa,db dis){
    if(f[now]*dis!=1){
        db x=(f[fa]-f[now]*dis)/(1-f[now]*dis);
        f[now]=f[now]+x*dis-f[now]*x*dis;
    } else f[now]=1;
    for(int i=head[now];i;i=edge[i].nxt){
        int to=edge[i].to;
        if(to==fa) continue;
        dfs(to,now,edge[i].dis);
    }
}

signed main(){
    n=getint();
    for(int i=1;i<n;i++){
        int x=getint(),y=getint(),z=getint();
        add(x,y,0.01*z),add(y,x,0.01*z);
    }
    for(int i=1;i<=n;i++) p[i]=0.01*getint();
    dfs(1,0);
    for(int i=head[1];i;i=edge[i].nxt)
        dfs(edge[i].to,1,edge[i].dis);
    db ans=0;
    for(int i=1;i<=n;i++) ans+=f[i];
    printf("%.6lf\n",ans);
    return 0;
}