1. 程式人生 > 其它 >51nod2626-未來常數【樹上啟發式合併,線段樹】

51nod2626-未來常數【樹上啟發式合併,線段樹】

正題

題目連結:http://www.51nod.com/Challenge/Problem.html#problemId=2626


題目大意

給出\(n\)個點的一棵樹,每個區間\([l,r]\)的代價是選出這個區間中的一個點\(x\)使得它走到所有點然後又回到\(x\)的路程最短長度,求一個隨機區間的期望代價。

\(1\leq n\leq 10^5\)


解題思路

考慮統計每條邊的貢獻,一條邊會被記入當且僅當分成的兩個樹各存在一個點在區間中。

考慮怎麼統計這個貢獻,計在兩棵樹中的點分別為\(0\)\(1\),那麼合法區間就是包含至少一個\(1\)和一個\(0\)的區間,用線段樹統計只包含\(0\)\(1\)

的區間減去即可。

然後在樹上的問題,所以直接上dsu on tree就好了。

時間複雜度:\(O(n\log^2n)\)

然後寫完題解突然發現線段樹合併好像也行而且更快(?


code

#include<cstdio>
#include<cstring>
#include<algorithm>
#define ll long long
using namespace std;
const ll N=1e5+10,P=1e9+7;
struct node{
	ll to,next;
}a[N<<1];
ll n,tot,ans,ls[N],siz[N],son[N];
ll w[N<<2],l0[N<<2],r0[N<<2],l1[N<<2],r1[N<<2];
void Merge(ll x,ll L,ll R){
	ll mid=(L+R)>>1;
	w[x]=w[x*2]+w[x*2+1]+r0[x*2]*l0[x*2+1]+r1[x*2]*l1[x*2+1];
	l0[x]=(l0[x*2]==mid-L+1)*l0[x*2+1]+l0[x*2];
	r0[x]=(r0[x*2+1]==R-mid)*r0[x*2]+r0[x*2+1];
	l1[x]=(l1[x*2]==mid-L+1)*l1[x*2+1]+l1[x*2];
	r1[x]=(r1[x*2+1]==R-mid)*r1[x*2]+r1[x*2+1];
	return;
}
void Build(ll x,ll L,ll R){
	if(L==R){w[x]=r0[x]=l0[x]=1;return;}
	ll mid=(L+R)>>1;
	Build(x*2,L,mid);
	Build(x*2+1,mid+1,R);
	Merge(x,L,R);
	return;
}
void Change(ll x,ll L,ll R,ll pos){
	if(L==R){swap(l0[x],l1[x]);swap(r0[x],r1[x]);return;}
	ll mid=(L+R)>>1;
	if(pos<=mid)Change(x*2,L,mid,pos);
	else Change(x*2+1,mid+1,R,pos);
	Merge(x,L,R);return;
}
void addl(ll x,ll y){
	a[++tot].to=y;
	a[tot].next=ls[x];
	ls[x]=tot;return;
}
void dfs(ll x,ll fa){
	siz[x]=1;
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		dfs(y,x);siz[x]+=siz[y];
		if(siz[y]>siz[son[x]])son[x]=y;
	}
	return;
}
void calc(ll x,ll fa){
	Change(1,1,n,x);
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa)continue;
		calc(y,x);
	}
	return;
}
void solve(ll x,ll fa,ll top){
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||y==son[x])continue;
		solve(y,x,y);
	}
	if(son[x])solve(son[x],x,top);
	Change(1,1,n,x);
	for(ll i=ls[x];i;i=a[i].next){
		ll y=a[i].to;
		if(y==fa||y==son[x])continue;
		calc(y,x);
	}
	(ans+=(n*(n+1)/2-w[1])%P)%=P;
	if(x==top)calc(x,fa);
	return;
}
ll power(ll x,ll b){
	ll ans=1;
	while(b){
		if(b&1)ans=ans*x%P;
		x=x*x%P;b>>=1;
	}
	return ans;
}
signed main()
{
	scanf("%lld",&n);
	for(ll i=1;i<n;i++){
		ll x,y;
		scanf("%lld%lld",&x,&y);
		addl(x,y);addl(y,x);
	}
	Build(1,1,n);
	dfs(1,1);
	solve(1,1,0);
	printf("%lld\n",ans*2*power(n*(n+1)/2%P,P-2)%P);
	return 0;
}