1. 程式人生 > 其它 >[數學記錄]CF1540B Tree Arrays

[數學記錄]CF1540B Tree Arrays

題意:

一棵樹上的點編號 \(1 \to n\),初始隨機一個點染色,每一步隨機一個與已染色點有連邊的點且未染色的點染色,求染色序列的點的編號序列的期望逆序對數。

\(n \leq 200\)

\(n\) 這麼小,所以不妨以此欽定根來做。

\(n\) 還是這麼小,不如列舉一對點 \(u,v\),計算其構成逆序對的概率。

容易發現,只用關心根到 \(u,v\) 的路程。其中也只用關心 lca 到兩邊的路程。

問題轉化為:給定兩個棧,每次隨機彈出一個棧的棧頂,求某個棧先被彈空的概率。

這是明顯的 dp 形式。設 \(dp_{i,j}\) 表示兩個棧大小分別為 \(i,j\)\(i\) 先空的概率,則 \(dp_{i,j} = \dfrac{1}{2}(dp_{i-1,j} + dp_{i,j-1})\)

結束了。複雜度 \(O(n^3 \log n)\),最後的 \(\log\) 是求 lca 的複雜度。

感覺一步步轉化其實都相當平凡,所以沒想到可能只是自己沒有用心去想。\(n\) 很小,這告訴我們可以去列舉相當多東西,減少相當多變數的,於是能想到直接把貢獻下放到點對上。

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
const int M = 205, mod = 1000000007, invt = (mod + 1) / 2;
int qpow(int a, int b){
    long long ans = 1ll;
    for(; b; b >>= 1) {if(b & 1) ans = 1ll * ans * a % mod; a = 1ll * a * a % mod;}
    return ans;
}
int add(int a, int b) {a += b; return a > mod ? a-mod : a;}
int minus(int a, int b) {a -= b; return a < 0 ? a+mod : a;}
void addn(int &x, int y) {x += y; if(x > mod) x -= mod;}
int f[M][M]; int dep[M];
struct edge{
    int to, nxt;
}e[M << 1];
int head[M], cnt1;
void link(int u, int v){
    e[++cnt1] = {v, head[u]}; head[u] = cnt1;
}
struct LCA{
    int f[M][15], d[M];
    LCA() {memset(f, 0, sizeof(f)); memset(d, 0, sizeof(d));}
    void dfs(int u, int fa){
        f[u][0] = fa; d[u] = d[fa] + 1;
        for(int i = 1; i <= 10; i++) f[u][i] = f[f[u][i - 1]][i - 1];
        for(int i = head[u]; i; i = e[i].nxt){
            int v = e[i].to; if(v == fa) continue; dfs(v, u);
        }
    }
    int lca(int u, int v){
        if(d[u] < d[v]) swap(u, v);
        for(int i = 10; i >= 0; i--) if(d[f[u][i]] >= d[v]) u = f[u][i];
        if(u == v) return u;
        for(int i = 10; i >= 0; i--) if(f[u][i] != f[v][i]) u = f[u][i], v = f[v][i];
        return u == v ? u : f[u][0];
    }
    void init(int rt) {
        memset(f, 0, sizeof(f)); memset(d, 0, sizeof(d));
        d[rt] = 0; dfs(rt, 0);
    }
    int dis(int u, int f) {return abs(d[u] - d[f]);}
}t;
int n;
int main(){
    scanf("%d", &n);
    for(int i = 1; i < n; i++){
        int u, v; scanf("%d %d", &u, &v);
        link(u, v); link(v, u);
    }
    for(int i = 1; i <= n; i++) f[0][i] = 1;
    for(int i = 1; i <= n; i++) 
        for(int j = 1; j <= n; j++) 
            f[i][j] = 1ll * add(f[i-1][j], f[i][j-1]) * invt % mod;
    // for(int i = 0; i <= n; i++) {
    //     for(int j = 0; j <= n; j++) printf("%d ", f[i][j]);
    //     printf("\n");
    // }
    int ans = 0;
    for(int i = 1; i <= n; i++) {
        t.init(i);
        for(int u = 1; u <= n; u++) {
            for(int v = 1; v < u; v++) {
                if(u == v) continue;
                int l = t.lca(u, v);
                addn(ans, f[t.dis(u, l)][t.dis(v, l)]);
            }
        }
    }
    printf("%d\n", 1ll * ans * qpow(n, mod-2) % mod);
}