【JZOJ A組】naive 的圖
阿新 • • 發佈:2018-11-01
Description
眾所周知,小 naive 有一張 n 個點,m 條邊的帶權無向圖。第 i 個點的顏色為 ci。d(s, t)表示從點 s 到點 t 的權值最小的路徑的權值,一條路徑的權值定義為路徑上權值最大的邊的權值。
求所有滿足 u < v, |cu − cv| ≥ L 的點對 (u, v) 的 d(u, v) 之和。
Input
輸入檔案為 graph.in。
第一行,三個整數 n, m, L,表示點數,邊數和引數 L。
第二行,n 個整數,第 i 個數為第 i 個點的顏色 ci。接下來 m 行,每行三個整數 ui, vi, wi,描述了一條邊。
Output
輸出檔案為 graph.out。
共一行,一個整數,表示答案。
Sample Input
4 5 2
6 4 5 2
1 2 8
2 3 7
2 4 8
1 3 2
1 4 1
Sample Output
17
樣例解釋
滿足條件的點對:(1, 2),(1, 4),(2, 4),(3, 4),答案為 7 + 1 + 7 + 2 = 17。
Data Constraint
對於每個測試點內的資料:
思路
先建出 Kruskal 重構樹,每條邊的貢獻次數為它連線的兩個子樹之間的顏色之差大於等於 L 的點對數,可以發現 ∑ min(size(lef tchildi), size(rightchildi)) = O(nlog2n)。
對於每條邊我們列舉 size 較小的那棵子樹內的點,算出在另一棵子樹中能與它組成點對的點的個數。這個問題實際上就是詢問在 dfs 序的一段區間上並且顏色不在一段區間內的點數,二維數點問題可以離線樹狀陣列完成。
總的時間複雜度為 O(mlog2m + nlog2n)
程式碼
#include<cstdio> #include<iostream> #include<algorithm> #define ll long long using namespace std; const int inf=0x3f3f3f3f,N=2e5+77,M=5e5+77; ll n,m,l,tot,num,sum,ans,c[N],f[N],d[N],list[N]; struct E { ll x,y,z; }a[M]; struct node { ll to,next; }e[N]; struct tr{ ll l,r,v; }tr[100*N]; bool cmp(E a,E b) { return a.z<b.z; } ll gf(ll x) { return f[x]==x?x:f[x]=gf(f[x]); } void add(ll x,ll y) { e[++tot].to=y; e[tot].next=list[x]; list[x]=tot; } ll query(ll d,ll l,ll r,ll st,ll ed) { if(st>ed) return 0; if(l==st&&r==ed) return tr[d].v; ll mid=(l+r)>>1; if(ed<=mid) return query(tr[d].l,l,mid,st,ed); else if(st>mid) return query(tr[d].r,mid+1,r,st,ed); else return query(tr[d].l,l,mid,st,mid)+query(tr[d].r,mid+1,r,mid+1,ed); } void dfs(ll d,ll x,ll fa) { sum+=query(d,0,inf,0,c[x]-l)+query(d,0,inf,c[x]+l,inf); for(ll i=list[x]; i; i=e[i].next) if(e[i].to!=fa) dfs(d,e[i].to,x); } void ins(ll d,ll l,ll r,ll x) { if(l==r) { tr[d].v++; return; } ll mid=(l+r)>>1; if(x<=mid) { if(!tr[d].l) tr[d].l=++num; ins(tr[d].l,l,mid,x); } else { if(!tr[d].r) tr[d].r=++num; ins(tr[d].r,mid+1,r,x); } tr[d].v=tr[tr[d].l].v+tr[tr[d].r].v; } void merge(ll l,ll r,ll st,ll ed) { if(st==ed) { tr[r].v+=tr[l].v; return; } ll mid=(st+ed)>>1; if(tr[l].l&&tr[r].l) merge(tr[l].l,tr[r].l,st,mid); else if(tr[l].l) tr[r].l=tr[l].l; if(tr[l].r&&tr[r].r) merge(tr[l].r,tr[r].r,mid+1,ed); else if(tr[l].r) tr[r].r=tr[l].r; tr[r].v=tr[tr[r].l].v+tr[tr[r].r].v; } int main() { freopen("graph.in","r",stdin),freopen("graph.out","w",stdout); scanf("%lld%lld%lld",&n,&m,&l),num=n; for(ll i=1; i<=n; i++) scanf("%lld",&c[i]),f[i]=i,d[i]=1,ins(i,0,inf,c[i]); for(ll i=1; i<=m; i++) scanf("%lld%lld%lld",&a[i].x,&a[i].y,&a[i].z); sort(a+1,a+m+1,cmp); ll i=1,tot=0; while(tot<n-1) { ll u=gf(a[i].x),v=gf(a[i].y); if(u!=v) { if(d[u]>d[v]) swap(a[i].x,a[i].y),swap(u,v); if(!l) sum=d[u]*d[v]; else sum=0,dfs(v,u,0); ans+=sum*a[i].z,add(v,u),merge(u,v,0,inf); f[u]=v,d[v]+=d[u];d[u]=0; ++tot; } i++; } printf("%lld\n",ans); }