1. 程式人生 > >1036 商務旅行 lca 離線

1036 商務旅行 lca 離線

題目描述 Description

某首都城市的商人要經常到各城鎮去做生意,他們按自己的路線去做,目的是為了更好的節約時間。

假設有N個城鎮,首都編號為1,商人從首都出發,其他各城鎮之間都有道路連線,任意兩個城鎮之間如果有直連道路,在他們之間行駛需要花費單位時間。該國公路網路發達,從首都出發能到達任意一個城鎮,並且公路網路不會存在環。

你的任務是幫助該商人計算一下他的最短旅行時間。

輸入描述 Input Description

輸入檔案中的第一行有一個整數N,1<=n<=30 000,為城鎮的數目。下面N-1行,每行由兩個整數a 和b (1<=ab<=n; a<>b)組成,表示城鎮a和城鎮b有公路連線。在第N+1行為一個整數M,下面的M行,每行有該商人需要順次經過的各城鎮編號。

輸出描述 Output Description

    在輸出檔案中輸出該商人旅行的最短時間。

樣例輸入 Sample Input

5
1 2
1 5
3 5
4 5
4
1
3
2
5

樣例輸出 Sample Output

7

#include<cstdio>
#include<cmath>
#include<iostream>
#include<cstring>
#include<algorithm>
#include<map>
#include<vector>
#define N 100005 
using namespace std;
int father[30010*2],vis[30010*2];
int head[30010*2],head1[30010*2],dis[30010*2];
int root[30010*2],ans[30010*2]; 
int z,zz;
int n,m;
struct ac
{
	int u,v,next,lca,w,id;
}r[30010*2],rr[30010*2];
void init()
{
	z=zz=0;
	memset(vis,0,sizeof(vis));
	memset(head,-1,sizeof(head));
	memset(head1,-1,sizeof(head1));
	memset(root,1,sizeof(root));
	memset(dis,0,sizeof(dis));
	memset(ans,0,sizeof(ans));
	for(int i=1;i<=n;i++)
	{
		father[i]=i;
	}
}
void add(int u,int v,int w)
{
	r[z].u=u;
	r[z].v=v;
	r[z].w=w;
	r[z].next=head[u];
	head[u]=z++;
}
void add1(int u,int v,int id)
{
	rr[zz].u=u;
	rr[zz].v=v;
	rr[zz].id=id;
	rr[zz].lca=-1;
	rr[zz].next=head1[u];
	head1[u]=zz++;
}
int find(int w)
{
	if(w==father[w])
	return w;
	return father[w]=find(father[w]);
}
void link(int x,int y)
{
	x=find(x);
	y=find(y);
	if(x!=y)
	father[y]=x;
}
void taxjan(int cur)
{
	vis[cur]=1;
	for(int i=head[cur];i+1;i=r[i].next)
	{
		int v=r[i].v;
		int w=r[i].w;
		if(!vis[v])
		{
			dis[v]=dis[cur]+w;
			taxjan(v);
			link(cur,v);
		}
	}
	for(int i=head1[cur];i+1;i=rr[i].next)
	{
		int v=rr[i].v;
		if(vis[v])
		{
			int z=find(v);
			ans[rr[i].id]=dis[cur]-2*dis[z]+dis[v];
		}
	}
	
}
int main()
{
		cin>>n;
		init();
		for(int i=1;i<n;i++)
		{
			int u,v;
			cin>>u>>v;
			root[v]=0;
			add(u,v,1);
			add(v,u,1);
		}
		int m,x;
		cin>>m;
		int c,d;
	   	cin>>c;
		for(int i=1;i<=m-1;i++){
		if(i%2!=0)cin>>d;
		else cin>>c;
		add1(c,d,i);
		add1(d,c,i);
		}
		taxjan(1);
		int sum=0;
		for(int i=1;i<=m-1;i++)
		{
			sum+=ans[i];
		}
		cout<<sum<<endl;

	return 0;
}