Subway Lines(樹上兩條路的交點數)
阿新 • • 發佈:2018-12-22
題意: 給出一棵樹,n節點,每次詢問給兩對葉子,求這兩對葉子產生路徑的交集
解析:
找被走過兩次的點->走被走過兩次的所有Lca,Lca所構成的那一段長度就是點的數量
顯然,目標線段的端點一定是這些葉子節點的某個Lca
步驟:
- 找到所有Lca,放入set
- 統計哪些Lca被走過兩次
- 怎麼判斷走過幾次:一對葉子(x1,x2)與它的節點l1,假設某個點在x1到l1之間或x2到l2之間,則在路徑上
- 怎麼判斷在x1和l1之間:
Lca(x1,p)==p&&Lca(p,l1)==l1
- 對於走過兩次的點,求出組成路徑的長度
- 怎麼求路徑:首先是所有點的Lca,作為Root,刪除除了最底下點以外的所有點
- 怎麼刪除:
if(Lca(p1,p2)==p2)Erase(p2)
#include<bits/stdc++.h>
using namespace std;
#define maxn 500005
int head[maxn<<1],fa[maxn][25],vis[maxn],cnt,dep[maxn],dis[maxn];
struct node
{
int to,next,wei;
}e[maxn<<1];
void add(int x,int y)
{
e[cnt].to=y;
e[cnt].next=head[x];
head[ x]=cnt++;
}
void bfs()
{
fa[1][0]=1;
dep[1]=0;
dis[1]=0;
queue<int>Q;
Q.push(1);
while(!Q.empty())
{
int u,v;
u=Q.front();
Q.pop();
for(int i=1;i<=16;i++)
fa[u][i]=fa[fa[u][i-1]][i-1];
for(int i=head[u];~i;i=e[i].next)
{
v=e[i].to;
if(v==fa[u][0])
continue;
dis[v]=dis[u]+e[i].wei;
dep[v]=dep[u]+1;
fa[v][0]=u;
Q.push(v);
}
}
}
int lca(int x,int y){
if(dep[x]<dep[y])swap(x,y);
for(int i=16;i>=0;i--)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
if(x==y)return x;
for(int i=16;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
return fa[x][0];
}
set<int>st;
map<int,int>mp;
vector<int>fin;
int main()
{
int n,q;
scanf("%d%d",&n,&q);
int l,r;
memset(head,-1,sizeof(head));
for(int i=1;i<n;i++)
scanf("%d%d",&l,&r),add(l,r),add(r,l);
bfs();
while(q--)
{
st.clear();
fin.clear();
mp.clear();
int a,b,c,d;
scanf("%d%d%d%d",&a,&b,&c,&d);
st.insert(lca(a,b)),st.insert(lca(a,c)),st.insert(lca(a,d));
st.insert(lca(b,c)),st.insert(lca(b,d)); st.insert(lca(c,d));
int l1=lca(a,b),l2=lca(c,d);
for(set<int>::iterator it=st.begin();it!=st.end();it++)
{
int p=*it;
if((lca(a,p)==p&&lca(p,l1)==l1)||(lca(b,p)==p&&lca(p,l1)==l1))
mp[p]++;
}
for(set<int>::iterator it=st.begin();it!=st.end();it++)
{
int p=*it;
if((lca(c,p)==p&&lca(p,l2)==l2)||(lca(d,p)==p&&lca(p,l2)==l2))
mp[p]++;
}
int ans=0;
for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
{
if(it->second>=2)
fin.push_back(it->first);
}
if(fin.size()==0)
{
printf("0\n");
continue;
}
if(fin.size()==1)
{
printf("1\n");
continue;
}
int l3,deep=1e9;
for(int i=0;i<fin.size();i++)
{
if(dep[fin[i]]<deep)
l3=fin[i],deep=dep[fin[i]];
}
while(fin.size()>2)
{
for(int i=0;i<fin.size()-1;i++)
{
for(int j=1;j<fin.size();j++)
{
if(lca(fin[i],fin[j])==fin[i])
fin.erase(fin.begin()+i);
else if(lca(fin[i],fin[j])==fin[j])
fin.erase(fin.begin()+j);
}
}
}
if(lca(fin[0],fin[1])==fin[0])
fin.erase(fin.begin());
else if(lca(fin[0],fin[1])==fin[1])
fin.erase(fin.begin()+1);
for(int i=0;i<fin.size();i++)
ans+=(dep[fin[i]]-dep[l3]);
ans++;
printf("%d\n",ans);
}
return 0;
}