1. 程式人生 > 實用技巧 >Solution -「CF 1060F」Shrinking Tree

Solution -「CF 1060F」Shrinking Tree

\(\mathcal{Description}\)

  Link.

  給定一棵 \(n\) 個點的樹,反覆隨機選取一條邊,合併其兩端兩點,新點編號在兩端兩點等概率選取。問每個點留到最後的概率。

  \(n\le50\)

\(\mathcal{Solution}\)

  推薦 @ywy_c_asm部落格 owo。

  所有的操作方案數是 \((n-1)!\),我們可以按刪邊順序看做一個長度為 \(n-1\) 的序列。對於每個點分別計算答案,把當前要算的點提為根(記為 \(r\)),我們只需要求出 \(r\) 在所有操作序列中存活的概率和除以 \((n-1)!\) 即可。

  令 \(f(u,i)\)

\(r\) 已經走到 \(u\)\(u\) 子樹內還剩下 \(i\) 條邊沒刪(沒加入刪邊序列),最終 \(u\)(即 \(r\))存活的概率和。顯然答案為 \(f(r,n-1)\),邊界 \(f(leaf,0)=1\)


  第一步,考慮兒子 \(v\)\(u\) 合併。相當於需要考慮邊 \((u,v)\) 在操作序列中的位置。粗略來說,若 \(r\) 沒走到 \(u\),我們並不關心 \(u\) 號結點的生死;而 \(r\)\(u\) 後,\(u\) 就必須存活。

  定義輔助狀態 \(g(u,i)\) 表示 \(u\) 子樹內以及 \(u\) 的父邊還剩下 \(i\)

條邊沒刪,最終 \(u\) 存活的概率和。現在我們要計算 \(g(u,i)\)

  第一類,\((u,v)\) 保留到 \(r\) 到達 \(u\) 後再刪,那麼就涉及到 \(u\) 點存活的概率。於是有轉移:

\[g(u,i)=\frac{1}2\sum_{v\in son_u\land j\in[0,i)}f(v,j) \]

  第二類,\((u,v)\)\(r\) 到達 \(u\) 之前就刪,那就很隨意啦—— \(v\) 子樹中已刪除了 \(siz_v-1-i\) 條邊,我們把 \((u,v)\) 隨便插進一個位置就好,即:

\[g(u,i)=(siz_v-i)f(v,i) \]

  上兩類轉移貢獻之和即為最終的 \(g(u,i)\)


  考慮合併,始終記住刪除的“序列意義”——保留的邊(狀態第二維)在刪邊序列的右端,其它的邊在刪邊序列的左端。合併兩個刪邊序列,仍需要保證這一點,那麼分別用組合數合併已刪除的左端序列和待刪除的右端序列即可。下是 @ywy_c_asm 部落格的一張圖 owo(紅色已刪除,藍色待刪除):

  答案呼之欲出啦:

\[f'(u,i+j)=\sum_{i,j}\binom{i+j}{i}\binom{(siz_u-1)+siz_v-i-j}{(siz_u-1)-i}f(u,i)g(v,j) \]

  兩個組合數分別對應分配已刪和待刪的方案數。

  最終,複雜度 \(\mathcal O(n^4)\) 解決了這道毒瘤 DP qwq。

\(\mathcal{Code}\)

#include <cstdio>
#include <cstring>

const int MAXN = 50;
int n, ecnt, head[MAXN + 5], siz[MAXN + 5];
double fac[MAXN + 5];
double f[MAXN + 5][MAXN + 5];
double g[MAXN + 5], h[MAXN + 5];

struct Edge { int to, nxt; } graph[MAXN * 2 + 5];

inline void link ( const int s, const int t ) {
	graph[++ ecnt] = { t, head[s] };
	head[s] = ecnt;
}

inline void init () {
	fac[0] = 1;
	for ( int i = 1; i <= n; ++ i ) fac[i] = fac[i - 1] * i;
}

inline double comb ( const int n, const int m ) {
	return n < m ? 0 : fac[n] / fac[m] / fac[n - m];
}

inline void solve ( const int u, const int fa ) {
	f[u][0] = siz[u] = 1;
	for ( int i = head[u], v; i; i = graph[i].nxt ) {
		if ( ( v = graph[i].to ) ^ fa ) {
			solve ( v, u );
			for ( int j = 0; j <= siz[v]; ++ j ) {
				g[j] = 0;
				for ( int k = 0; k < j; ++ k ) g[j] += 0.5 * f[v][k];
				g[j] += ( siz[v] - j ) * f[v][j];
			}
			for ( int j = 0; j <= siz[v] + siz[u]; ++ j ) h[j] = 0;
			for ( int j = 0; j < siz[u]; ++ j ) {
				for ( int k = 0; k <= siz[v]; ++ k ) {
					h[j + k] += f[u][j] * g[k] * comb ( j + k, j )
						* comb ( siz[u] + siz[v] - 1 - j - k, siz[u] - 1 - j );
				}
			}
			siz[u] += siz[v];
			for ( int j = 0; j <= siz[u]; ++ j ) f[u][j] = h[j];
		}
	}
}

int main () {
	scanf ( "%d", &n ), init ();
	for ( int i = 1, u, v; i < n; ++ i ) {
		scanf ( "%d %d", &u, &v );
		link ( u, v ), link ( v, u );
	}
	for ( int i = 1; i <= n; ++ i ) {
		memset ( f, 0, sizeof f );
		solve ( i, 0 );
		printf ( "%.12f\n", f[i][n - 1] / fac[n - 1] );
	}
	return 0;
}