【POJ 1741】Tree
【題目】
Description
Give a tree with n vertices,each edge has a length(positive integer less than 1001).
Define dist(u,v)=The min distance between node u and v.
Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.
Write a program that will count how many pairs which are valid for a given tree.
Input
The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.
The last test case is followed by two zeros.
Output
For each test case output the answer on a single line.
Sample Input
5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0
Sample Output
8
【分析】
題目大意:(多組資料)給出一棵邊帶權樹,求出這棵樹中距離不超過 的點對的數量
題解:點分治模板題
由於這是我的第一道點分治題,我還是好好寫一下部落格吧
先假設這是一道有根樹,那滿足條件的點對必然是以下兩種情況:
- 它們的路徑經過根節點
- 它們的路徑不經過根節點(也就是說它們在同一個子樹中)
對於 2,可以把它當成子問題,遞迴求解,現在就是討論如何求出 1
假設 為 到根的路徑長度,用 求出所有點到根的距離,然後對所有 排序,這樣就便於統計 的總數,但這樣做我們用把 2 的部分情況考慮進去,還要減掉這些情況
怎麼選這個根呢,考慮用重心,因為減去重心後,子樹的 都會減少一半,這樣可以保證複雜度
遞迴層數 O( ), 是 O( ),總複雜度是O( )
【程式碼】
#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 50005
#define inf (1ll<<31ll)-1
using namespace std;
int n,k,t,ans,num,root,sum;
int d[N],size[N],Max[N];
int first[N],v[N],w[N],next[N];
bool vis[N];
void add(int x,int y,int z)
{
t++;
next[t]=first[x];
first[x]=t;
v[t]=y;
w[t]=z;
}
void dfs(int x,int father)
{
int i,j;
Max[x]=0;
size[x]=1;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
{
dfs(j,x);
size[x]+=size[j];
Max[x]=max(Max[x],size[j]);
}
}
}
void find(int rt,int x,int father)
{
int i,j;
Max[x]=max(Max[x],size[rt]-size[x]);
if(num>Max[x]) num=Max[x],root=x;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
find(rt,j,x);
}
}
void dist(int x,int father,int len)
{
int i,j;
d[++sum]=len;
for(i=first[x];i;i=next[i])
{
j=v[i];
if(j!=father&&!vis[j])
dist(j,x,len+w[i]);
}
}
int calc(int x,int l)
{
sum=0,dist(x,0,l);
sort(d+1,d+sum+1);
int ans=0,i=1,j=sum;
while(i<j)
{
while(d[i]+d[j]>k&&i<j) j--;
ans+=j-i;i++;
}
return ans;
}
void solve(int x)
{
int i,j;
dfs(x,0);
num=inf,find(x,x,0);
ans+=calc(root,0);
vis[root]=true;
for(i=first[root];i;i=next[i])
{
j=v[i];
if(!vis[j])
{
ans-=calc(j,w[i]);
solve(j);
}
}
}
int main()
{
int x,y,z,i;
while(~scanf("%d%d",&n,&k))
{
ans=0,t=0;
if(!n&&!k) break;
memset(first,0,sizeof(first));
memset(vis,false,sizeof(vis));
for(i=1;i<n;++i)
{
scanf("%d%d%d",&x,&y,&z);
add(x,y,z),add(y,x,z);
}
solve(1);
printf("%d\n",ans);
}
return 0;
}