程式人生 > POJ 1741 Tree (點分治)

poj 1741 Tree (點分治)

Tree
Time Limit: 1000MS Memory Limit: 30000K
Total Submissions: 24493 Accepted: 8161

Description

Given a tree with n vertices, where each edge has a length (a positive integer less than 1001).
Define dist(u,v) as the minimum distance between nodes u and v.
Given an integer k, a pair (u,v) of vertices is called valid if and only if dist(u,v) does not exceed k.
Write a program that counts how many valid pairs there are in a given tree.

Input

The input contains several test cases. The first line of each test case contains two integers n and k (n <= 10000). Each of the following n-1 lines contains three integers u, v, l, meaning there is an edge of length l between nodes u and v.
The last test case is followed by two zeros. 

Output

For each test case output the answer on a single line.

Sample Input

5 4
1 2 3
1 3 1
1 4 2
3 5 1
0 0

Sample Output

8

前幾天做2016 ICPC大連站的題目時，碰到一道要用到點分治的題，當時並不會，用樹DP碼了很久。

#include<stdio.h>
#include<string.h>
#include<algorithm>
using namespace std;
// Fill every byte of the array `a` with the byte value `b`.
#define Del(a,b) memset((a),(b),sizeof(a))

// Array bound: n <= 10000 per the statement, plus slack.
const int maxn=10000+50;

// One directed half of an undirected edge, kept in a linked adjacency list.
struct Edge
{
	int u;     // tail vertex (written by addedge, not read elsewhere here)
	int v;     // head vertex
	int w;     // edge length
	int next;  // index of the next edge leaving u, or -1 at the end
}edge[maxn<<1];  // every undirected edge is stored in both directions

int cnt;          // number of directed edges stored so far
int head[maxn];   // head[x]: first edge leaving x, -1 when x has none
int sz[maxn];     // subtree sizes from the latest getroot() pass
int fz[maxn];     // largest piece that would remain if the node were removed
int vis[maxn];    // set to 1 once a node has been taken as a centroid
int root;         // best centroid candidate found by getroot()
int n;            // vertex count of the current test case
int k;            // distance limit of the current test case
int size;         // component size used by getroot()
void addedge(int u,int v,int w)
{
	edge[cnt].u=u;
	edge[cnt].v=v;
	edge[cnt].w=w;
	edge[cnt].next=head[u];
	head[u]=cnt++;
}
void getroot(int x,int pre)
{
	sz[x]=1; fz[x]=0;
	for(int i=head[x];i!=-1;i=edge[i].next)
	{
		if(edge[i].v==pre||vis[edge[i].v]) continue;
		getroot(edge[i].v,x);
		sz[x]+=sz[edge[i].v];
		fz[x]=max(fz[x],sz[edge[i].v]);
	}
	fz[x]=max(fz[x],size-sz[x]);
	if(fz[x]<fz[root]) root=x;
}
// Scratch state for cal()/getdis():
//   dis[x]  : cal()'s offset d plus the path length from cal()'s start node.
//   dsort   : the collected dis values, sorted by cal().
//   numdis  : how many entries of dsort are in use.
int dis[maxn];
int dsort[maxn];
int numdis;
// Depth-first walk from x that never steps back to `pre` or into removed
// (vis) nodes. Appends dis[x] to dsort, then gives each child its distance
// (dis[x] plus the edge length) before descending into it.
void getdis(int x,int pre)
{
	dsort[numdis]=dis[x];
	++numdis;
	for(int e=head[x];e!=-1;e=edge[e].next)
	{
		const int to=edge[e].v;
		if(to!=pre&&!vis[to])
		{
			dis[to]=dis[x]+edge[e].w;
			getdis(to,x);
		}
	}
}
// Count unordered pairs of distinct nodes in x's unremoved component whose
// distances (each measured from x, offset by d) sum to at most k.
// Collects the distances with getdis(), sorts them, then sweeps two
// pointers: for each left index, every index in (lo, hi] pairs with it.
// Leaves numdis equal to the number of nodes visited.
int cal(int x,int d)
{
	dis[x]=d;
	numdis=0;
	getdis(x,-1);
	std::sort(dsort,dsort+numdis);
	int pairs=0;
	for(int lo=0,hi=numdis-1;lo<hi;++lo)
	{
		// hi only moves left: a larger dsort[lo] needs an equal-or-smaller partner.
		while(lo<hi&&dsort[lo]+dsort[hi]>k) --hi;
		pairs+=hi-lo;
	}
	return pairs;
}

// Running count of valid pairs for the current test case (reset in main).
int anss=0;
void solve(int x)
{
	anss+=cal(x,0);
	vis[x]=1;
	for(int i=head[x];i!=-1;i=edge[i].next)
	{
		if(vis[edge[i].v])continue;
		anss-=cal(edge[i].v,edge[i].w);
		fz[0]=size=sz[edge[i].v];
		root=0;
		getroot(edge[i].v,-1);
		solve(root);
	}
}
// Reset per-test-case graph state: no edges stored, every adjacency list
// empty (-1), and no node removed yet.
void init()
{
	cnt=0;
	Del(vis,0);
	Del(head,-1);
}
int main()
{
	while(scanf("%d%d",&n,&k)==2)
	{
		if(n==0&&k==0) break;
		init();
		for(int i=0;i<n-1;i++)
		{
			int u,v,w;
			scanf("%d%d%d",&u,&v,&w);
			addedge(u,v,w);
			addedge(v,u,w);
		}
		anss=0;
		root=0;
		fz[root]=size=n;
		getroot(1,-1);
		solve(root);
		printf("%d\n",anss);
	}
}