1. 程式人生 > 其它 >Solution -「Gym 102979E」Expected Distance

Solution -「Gym 102979E」Expected Distance

\(\mathcal{Description}\)

  Link.

  用給定的 \(\{a_{n-1}\},\{c_n\}\) 生成一棵含有 \(n\) 個點的樹,其中 \(u\) 連向 \([1,u)\) 中的某個 \(v\),概率為 \(\frac{a_v}{a_1+a_2+\cdots+a_{u-1}}\),邊權為 \(c_u+c_v\)。並給出 \(q\) 組詢問 \((u_i,v_i)\),每次回答 \(u_i\)\(v_i\) 的樹上距離的期望。答案對 \((10^9+7)\) 取模。

  \(n,q\le3\times10^5\)

\(\mathcal{Solution}\)

\[\newcommand{\vct}[1]{\boldsymbol{#1}}\newcommand{\stir}[2]{\genfrac{\{}{\}}{0pt}{}{#1}{#2}}\newcommand{\opn}[1]{\operatorname{#1}}\newcommand{\lcm}[0]{\opn{lcm}}\newcommand{\sg}[0]{\opn{sg}}\newcommand{\dist}[0]{\opn{dist}}\newcommand{\lca}[0]{\opn{lca}}\newcommand{\floor}[2]{\left\lfloor\frac{#1}{#2}\right\rfloor}\newcommand{\ceil}[2]{\left\lceil\frac{#1}{#2}\right\rceil} \]

  問題卡殼,必有結論。

  令 \(1\) 為根,把 \(\dist(u,v)\) 轉化成 \(\dist(1,u)+\dist(1,v)-2\dist(\lca(u,v))\)。記 \(f(u)=E(\dist(1,u))\),顯然有

\[f(u)=c_u+\frac{1}{s_{u-1}}\sum_{v<u}a_v(f_v+c_v). \]

其中 \(s_i=\sum_{j=1}^ia_i\),可見 \(f\) 可以輕易地 \(\mathcal O(n)\) 求出。我們接下來研究 \(\dist(\lca(u,v))\)。不妨設 \(u<v\),可以發現一個結論:

\[\forall v>u,~E(\dist(\lca(u,v)))=g(u). \]

其中 \(g(u)\)

是僅與 \(u\) 有關的量。

證明   考慮求 $\lca(u,v)$ 的方式,在 $v$ 沿著祖先跳躍時,我們只關心第一次使得 $v\le u$ 的位置。此時僅有兩種情況
  • \(v=u\),概率為 \(\frac{a_u}{s_u}\)
  • \(v<u\),概率為 \(\frac{s_{u-1}}{s_u}\)

  可見與 \(v\) 無關。

  在證明的基礎上,亦能得到 \(g(u)\) 的轉移:

\[g(u)=\frac{1}{s_u}\left(a_uc_u+\sum_{v<u}a_vg_v\right). \]

也能 \(\mathcal O(n)\) 求出,所以本題就解決啦。

\(\mathcal{Code}\)

/*~Rainybunny~*/

#include <bits/stdc++.h>

#define rep( i, l, r ) for ( int i = l, rep##i = r; i <= rep##i; ++i )
#define per( i, r, l ) for ( int i = r, per##i = l; i >= per##i; --i )

const int MAXN = 3e5, MOD = 1e9 + 7;
int n, q, a[MAXN + 5], s[MAXN + 5], invs[MAXN + 5];
int c[MAXN + 5], f[MAXN + 5], g[MAXN + 5];

inline int mul( const int a, const int b ) { return 1ll * a * b % MOD; }
inline int sub( int a, const int b ) { return ( a -= b ) < 0 ? a + MOD : a; }
inline int add( int a, const int b ) { return ( a += b ) < MOD ? a : a - MOD; }
inline int mpow( int a, int b ) {
    int ret = 1;
    for ( ; b; a = mul( a, a ), b >>= 1 ) ret = mul( ret, b & 1 ? a : 1 );
    return ret;
}

int main() {
    std::ios::sync_with_stdio( false ), std::cin.tie( 0 );

    std::cin >> n >> q;
    rep ( i, 1, n - 1 ) {
        std::cin >> a[i], s[i] = a[i] + s[i - 1];
        invs[i] = mpow( s[i], MOD - 2 );
    }
    rep ( i, 1, n ) std::cin >> c[i];

    for ( int i = 2, pre = mul( a[1], c[1] ); i <= n; ++i ) {
        f[i] = add( c[i], mul( invs[i - 1], pre ) );
        pre = add( pre, mul( a[i], add( f[i], c[i] ) ) );
    }

    for ( int i = 2, pre = 0; i < n; ++i ) {
        g[i] = mul( invs[i], add( mul( a[i], f[i] ), pre ) );
        pre = add( pre, mul( a[i], g[i] ) );
    }

    for ( int u, v; q--; ) {
        std::cin >> u >> v;
        if ( u > v ) u ^= v ^= u ^= v;
        if ( u == v ) std::cout << "0\n";
        else std::cout << sub( add( f[u], f[v] ), mul( 2, g[u] ) ) << '\n';
    }
    return 0;
}