1. 程式人生 > >Gym - 101908L Subway Lines —— lca處理樹上兩個線段相交部分

Gym - 101908L Subway Lines —— lca處理樹上兩個線段相交部分

The subway system of a major city is formed by a set of stations and tunnels that connect some pairs of stations. The system was designed so that there is exactly one sequence of tunnels linking any pair of stations. The stations for which there is only one tunnel passing through are called terminals. There are several train lines that make round trips between pairs of terminal stations, passing only through the stations in the unique path between them. People are complaining about the current lines. Therefore, the mayor ordered that the lines are redefined from scratch. As the system has many stations, we need to help the engineers, who are trying to decide which pair of terminals will define a line.

The figure illustrates a system where terminal stations are shown as filled circles and the non-terminals are shown as empty circles. On the leftmost picture, if the pair (A,B) define a line and the pair (C,D) defines another, they won’t have any common station. But, on the rightmost picture, we can see that if the pairs (E,F) and (G,H) are chosen, those two lines will have two stations in common.

Given the description of the system and a sequence of Q queries consisting of two pairs of terminals, your program should compute, for each query, how many stations the two lines defined by those pairs have in common.

Input
The first line of the input contains two integers N (5≤N≤105) and Q (1≤Q≤20000), representing respectively the number of stations and the number of queries. The stations are numbered from 1 to N. Each of the following N−1 lines contains two distinct integers U and V (1≤U,V≤N), indicating that there is a tunnel between stations U and V. Each of the following Q lines contains four distinct integers A,B,C,D (1≤A,B,C,D≤N), representing a query: the two train lines are defined by pairs (A,B) and (C,D).

Output
For each query, your program must print a line containing an integer representing how many stations the two train lines defined by the query would have in common.

Examples
Input
5 1
1 5
2 5
5 3
5 4
1 2 3 4
Output
1
Input
10 4
1 4
4 5
3 4
3 2
7 3
6 7
7 8
10 8
8 9
6 10 2 5
1 9 5 10
9 10 2 1
5 10 2 9
Output
0
4
0
3

題意:

給你一棵樹,之後給你q個詢問,每次詢問給你a,b,c,d四個點,問你從a到b和從c到d這兩條線上有多少相同的點。

題解:

很明顯可以知道,這些線只與他們的lca有關,所以我們先處理出他們的lca,共有6個,先加到一個set裡面,之後for一遍這個set,看看這個lca是否在a到b的路上,是否在c到d的路上,怎麼判斷?我們可以知道,如果這個lca在a到b的路上,那麼這個點一定在a到 l c a ( a , b ) lca(a,b) 的路上,也就是(lca(a,p)==p&&lca(p,l1)==l1)||(lca(b,p)==p&&lca(p,l1)==l1),l1就是a,b的lca,做完之後看看有多少個點走過的次數是大於等於2的,就說明這些點是共同的點。如果剩下來的是0,那麼就不相交,如果是1,那就代表只有一個點相交,如果不判掉會re。我是判了這兩個特殊情況。之後有兩種情況,一種是一個鏈,還有一種是樹。在迴圈中我用任意兩個數比較,若是lca是其中的一個,就說明這兩個是在同一條鏈上的,就把lca去掉。因為我迴圈的條件是size()>2,所以要判鏈的情況,就是dep最小的lca是端點的情況。最後就是dep減一減,最後要注意因為最小的lca是沒有被算進去的,所以+1。

#include<bits/stdc++.h>
using namespace std;
#define maxn 500005
int head[maxn<<1],fa[maxn][25],vis[maxn],cnt,dep[maxn],dis[maxn];
struct node
{
    int to,next,wei;
}e[maxn<<1];
void add(int x,int y)
{
    e[cnt].to=y;
    e[cnt].next=head[x];
    head[x]=cnt++;
}
void bfs()
{
    fa[1][0]=1;
    dep[1]=0;
    dis[1]=0;
    queue<int>Q;
    Q.push(1);
    while(!Q.empty())
    {
        int u,v;
        u=Q.front();
        Q.pop();
        for(int i=1;i<=16;i++)
            fa[u][i]=fa[fa[u][i-1]][i-1];
        for(int i=head[u];~i;i=e[i].next)
        {
            v=e[i].to;
            if(v==fa[u][0])
                continue;
            dis[v]=dis[u]+e[i].wei;
            dep[v]=dep[u]+1;
            fa[v][0]=u;
            Q.push(v);
        }
    }
}
int lca(int x,int y){
    if(dep[x]<dep[y])swap(x,y);
//令x為深度較深的點
    for(int i=16;i>=0;i--)if(dep[fa[x][i]]>=dep[y])x=fa[x][i];
//讓x向上走到與y同一深度
    if(x==y)return x; //如果直接是lca,直接返回
    for(int i=16;i>=0;i--)if(fa[x][i]!=fa[y][i])x=fa[x][i],y=fa[y][i];
//x,y同時向上走,直到父節點相同
    return fa[x][0]; //返回父節點
}
set<int>st;
map<int,int>mp;
vector<int>fin;
int main()
{
    int n,q;
    scanf("%d%d",&n,&q);
    int l,r;
    memset(head,-1,sizeof(head));
    for(int i=1;i<n;i++)
        scanf("%d%d",&l,&r),add(l,r),add(r,l);
    bfs();
    while(q--)
    {
        st.clear();
        fin.clear();
        mp.clear();
        int a,b,c,d;
        scanf("%d%d%d%d",&a,&b,&c,&d);
        st.insert(lca(a,b)),st.insert(lca(a,c)),st.insert(lca(a,d));
        st.insert(lca(b,c)),st.insert(lca(b,d));
        st.insert(lca(c,d));
        int l1=lca(a,b),l2=lca(c,d);
        for(set<int>::iterator it=st.begin();it!=st.end();it++)
        {
            int p=*it;
            if((lca(a,p)==p&&lca(p,l1)==l1)||(lca(b,p)==p&&lca(p,l1)==l1))
                mp[p]++;
        }
        for(set<int>::iterator it=st.begin();it!=st.end();it++)
        {
            int p=*it;
            if((lca(c,p)==p&&lca(p,l2)==l2)||(lca(d,p)==p&&lca(p,l2)==l2))
                mp[p]++;
        }
        //int l3=dep[l1]<dep[l2]?l1:l2;
        //cout<<"l3: "<<l3<<endl;
        int ans=0;
        for(map<int,int>::iterator it=mp.begin();it!=mp.end();it++)
        {
            if(it->second>=2)
                fin.push_back(it->first);//,cout<<"lca: "<<it->first<<endl;
        }
        if(fin.size()==0)
        {
            printf("0\n");
            continue;
        }
        if(fin.size()==1)
        {
            printf("1\n");
            continue;
        }
        int l3,deep=1e9;
        for(int i=0;i<fin.size();i++)
        {
            if(dep[fin[i]]<deep)
                l3=fin[i],deep=dep[fin[i]];
        }
        while(fin.size()>2)
        {
            for(int i=0;i<fin.size()-1;i++)
            {
                for(int j=1;j<fin.size();j++)
                {
                    if(lca(fin[i],fin[j])==fin[i])
                        fin.erase(fin.begin()+i);
                    else if(lca(fin[i],fin[j])==fin[j])
                        fin.erase(fin.begin()+j);
                }
            }
        }
        if(lca(fin[0],fin[1])==fin[0])
            fin.erase(fin.begin());
        else if(lca(fin[0],fin[1])==fin[1])
            fin.erase(fin.begin()+1);
        for(int i=0;i<fin.size();i++)
            ans+=(dep[fin[i]]-dep[l3]);
            ans++;
        printf("%d\n",ans);
    }
    return 0;
}