1. 程式人生 > >【POJ 1741】Tree

【POJ 1741】Tree

【題目】

傳送門

Description

Give a tree with n vertices,each edge has a length(positive integer less than 1001).

Define dist(u,v)=The min distance between node u and v.

Give an integer k,for every pair (u,v) of vertices is called valid if and only if dist(u,v) not exceed k.

Write a program that will count how many pairs which are valid for a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n, k. (n<=10000) The following n-1 lines each contains three integers u,v,l, which means there is an edge between node u and v of length l.

The last test case is followed by two zeros.

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

【分析】

題目大意:(多組資料)給出一棵邊帶權樹,求出這棵樹中距離不超過 k k 的點對的數量

題解:點分治模板題

由於這是我的第一道點分治題,我還是好好寫一下部落格吧

先假設這是一道有根樹,那滿足條件的點對必然是以下兩種情況:

  1. 它們的路徑經過根節點
  2. 它們的路徑不經過根節點(也就是說它們在同一個子樹中)

對於 2,可以把它當成子問題,遞迴求解,現在就是討論如何求出 1

假設 d i s i dis_i i i 到根的路徑長度,用 d f s dfs 求出所有點到根的距離,然後對所有 d i s dis 排序,這樣就便於統計 d i s x + d i s y k dis_x+dis_y≤k 的總數,但這樣做我們用把 2 的部分情況考慮進去,還要減掉這些情況

怎麼選這個根呢,考慮用重心,因為減去重心後,子樹的 s i z e size 都會減少一半,這樣可以保證複雜度

遞迴層數 O( l o g &ThickSpace; n log\;n ), s o r t sort 是 O( n l o g &ThickSpace; n n * log\;n ),總複雜度是O( n l o g 2 &ThickSpace; n n*log^2\;n


【程式碼】

#include<cstdio>
#include<cstring>
#include<algorithm>
#define N 50005
#define inf (1ll<<31ll)-1
using namespace std;
int n,k,t,ans,num,root,sum;
int d[N],size[N],Max[N];
int first[N],v[N],w[N],next[N];
bool vis[N];
void add(int x,int y,int z)
{
	t++;
	next[t]=first[x];
	first[x]=t;
	v[t]=y;
	w[t]=z;
}
void dfs(int x,int father)
{
	int i,j;
	Max[x]=0;
	size[x]=1;
	for(i=first[x];i;i=next[i])
	{
		j=v[i];
		if(j!=father&&!vis[j])
		{
			dfs(j,x);
			size[x]+=size[j];
			Max[x]=max(Max[x],size[j]);
		}
	}
}
void find(int rt,int x,int father)
{
	int i,j;
	Max[x]=max(Max[x],size[rt]-size[x]);
	if(num>Max[x])  num=Max[x],root=x;
	for(i=first[x];i;i=next[i])
	{
		j=v[i];
		if(j!=father&&!vis[j])
		  find(rt,j,x);
	}
}
void dist(int x,int father,int len)
{
	int i,j;
	d[++sum]=len;
	for(i=first[x];i;i=next[i])
	{
		j=v[i];
		if(j!=father&&!vis[j])
		  dist(j,x,len+w[i]);
	}
}
int calc(int x,int l)
{
	sum=0,dist(x,0,l);
	sort(d+1,d+sum+1);
	int ans=0,i=1,j=sum;
	while(i<j)
	{
		while(d[i]+d[j]>k&&i<j)  j--;
		ans+=j-i;i++;
	}
	return ans;
}
void solve(int x)
{
	int i,j;
	dfs(x,0);
	num=inf,find(x,x,0);
	ans+=calc(root,0);
	vis[root]=true;
	for(i=first[root];i;i=next[i])
	{
		j=v[i];
		if(!vis[j])
		{
			ans-=calc(j,w[i]);
			solve(j);
		}
	}
}
int main()
{
	int x,y,z,i;
	while(~scanf("%d%d",&n,&k))
	{
		ans=0,t=0;
		if(!n&&!k)  break;
		memset(first,0,sizeof(first));
		memset(vis,false,sizeof(vis));
		for(i=1;i<n;++i)
		{
			scanf("%d%d%d",&x,&y,&z);
			add(x,y,z),add(y,x,z);
		}
		solve(1);
		printf("%d\n",ans);
	}
	return 0;
}