51nod2626-未來常數【樹上啟發式合併,線段樹】
阿新 • • 發佈:2021-09-28
正題
題目連結:http://www.51nod.com/Challenge/Problem.html#problemId=2626
題目大意
給出\(n\)個點的一棵樹,每個區間\([l,r]\)的代價是選出這個區間中的一個點\(x\)使得它走到所有點然後又回到\(x\)的路程最短長度,求一個隨機區間的期望代價。
\(1\leq n\leq 10^5\)
解題思路
考慮統計每條邊的貢獻,一條邊會被記入當且僅當分成的兩個樹各存在一個點在區間中。
考慮怎麼統計這個貢獻,計在兩棵樹中的點分別為\(0\)和\(1\),那麼合法區間就是包含至少一個\(1\)和一個\(0\)的區間,用線段樹統計只包含\(0\)或\(1\)
然後在樹上的問題,所以直接上dsu on tree就好了。
時間複雜度:\(O(n\log^2n)\)
然後寫完題解突然發現線段樹合併好像也行而且更快(?
code
#include<cstdio> #include<cstring> #include<algorithm> #define ll long long using namespace std; const ll N=1e5+10,P=1e9+7; struct node{ ll to,next; }a[N<<1]; ll n,tot,ans,ls[N],siz[N],son[N]; ll w[N<<2],l0[N<<2],r0[N<<2],l1[N<<2],r1[N<<2]; void Merge(ll x,ll L,ll R){ ll mid=(L+R)>>1; w[x]=w[x*2]+w[x*2+1]+r0[x*2]*l0[x*2+1]+r1[x*2]*l1[x*2+1]; l0[x]=(l0[x*2]==mid-L+1)*l0[x*2+1]+l0[x*2]; r0[x]=(r0[x*2+1]==R-mid)*r0[x*2]+r0[x*2+1]; l1[x]=(l1[x*2]==mid-L+1)*l1[x*2+1]+l1[x*2]; r1[x]=(r1[x*2+1]==R-mid)*r1[x*2]+r1[x*2+1]; return; } void Build(ll x,ll L,ll R){ if(L==R){w[x]=r0[x]=l0[x]=1;return;} ll mid=(L+R)>>1; Build(x*2,L,mid); Build(x*2+1,mid+1,R); Merge(x,L,R); return; } void Change(ll x,ll L,ll R,ll pos){ if(L==R){swap(l0[x],l1[x]);swap(r0[x],r1[x]);return;} ll mid=(L+R)>>1; if(pos<=mid)Change(x*2,L,mid,pos); else Change(x*2+1,mid+1,R,pos); Merge(x,L,R);return; } void addl(ll x,ll y){ a[++tot].to=y; a[tot].next=ls[x]; ls[x]=tot;return; } void dfs(ll x,ll fa){ siz[x]=1; for(ll i=ls[x];i;i=a[i].next){ ll y=a[i].to; if(y==fa)continue; dfs(y,x);siz[x]+=siz[y]; if(siz[y]>siz[son[x]])son[x]=y; } return; } void calc(ll x,ll fa){ Change(1,1,n,x); for(ll i=ls[x];i;i=a[i].next){ ll y=a[i].to; if(y==fa)continue; calc(y,x); } return; } void solve(ll x,ll fa,ll top){ for(ll i=ls[x];i;i=a[i].next){ ll y=a[i].to; if(y==fa||y==son[x])continue; solve(y,x,y); } if(son[x])solve(son[x],x,top); Change(1,1,n,x); for(ll i=ls[x];i;i=a[i].next){ ll y=a[i].to; if(y==fa||y==son[x])continue; calc(y,x); } (ans+=(n*(n+1)/2-w[1])%P)%=P; if(x==top)calc(x,fa); return; } ll power(ll x,ll b){ ll ans=1; while(b){ if(b&1)ans=ans*x%P; x=x*x%P;b>>=1; } return ans; } signed main() { scanf("%lld",&n); for(ll i=1;i<n;i++){ ll x,y; scanf("%lld%lld",&x,&y); addl(x,y);addl(y,x); } Build(1,1,n); dfs(1,1); solve(1,1,0); printf("%lld\n",ans*2*power(n*(n+1)/2%P,P-2)%P); return 0; }