1. 程式人生 > 實用技巧 >多詢問 樹上距離A,B相等的點數。

多詢問 樹上距離A,B相等的點數。

題目

A and B and Lecture Rooms(https://ac.nowcoder.com/acm/problem/110856)

題目大意

給你一棵有n個節點的樹。
有m個詢問。每次詢問給你一個A,B。問樹上到A,B節點距離相等的點有多少個。

輸入

第一行包含整數n(1 ≤ n ≤ 10^5。
接下來的n-1線描述了走廊。第i行(1≤ i≤≤ n - 1)包含兩個整數ai和bi(1≤ ai, bi≤ n),ai和bi有一條邊。
下一行包含整數m(1 ≤ m ≤ 105)-查詢數。
接下來的m行描述查詢。每行包含兩個整數A和B(1≤ A, B≤ n)

樣例

輸入
4
1 2
2 3
2 4
2
1 2
1 3
輸出
0
2

輸出

對於每個詢問輸出有多少個節點到A和B的距離相等。

思路

我們分析如果存在這個點,這個點一定有一個在A-B的路徑上。
如果路徑長度為奇數就輸出0。
如果路徑長度為偶數。我們分析:
1.假設A的深度>B的深度

A~B的距離為2k,那麼這個點一定是A的第k級祖先C。
可以滿足的點就是C的子節點個數-A所在的鏈的所有節點(就是A的k-1祖先(4)的子樹大小)。
\(siz[lca(A, k)]-siz[lca(A, k-1)]\)
2.假設假設A的深度=B的深度

這個就是\(n-siz[lca(A, k-1)]-siz[lca(B, k-1)]\)

#include <bits/stdc++.h>
using namespace std;

vector<int> G[500005];
struct LCA{
    int d[500005], fa[500005][22], lg[500005], siz[500005];
    void init(int n, int root){//預處理
        for(int i = 1; i <= n; ++i){
            lg[i] = lg[i-1] + (1 << lg[i-1] == i);
        }
        dfs(root, 0);
    }
    void dfs(int now, int father){
        fa[now][0]=father; d[now]=d[father]+1; siz[now]=1;
        for(int i=1; i<=lg[d[now]]; ++i){
            fa[now][i]=fa[fa[now][i-1]][i-1];
        }
        for(auto x: G[now]){
            if(x!=father){
                dfs(x, now); siz[now]+=siz[x];
            }
        }
    }
    int lca(int x, int y){//LCA
        if(d[x]<d[y]) swap(x, y);
        while(d[x]>d[y]){
            x=fa[x][lg[d[x]-d[y]]-1];
        }
        if(x==y) return x;
        for(int k=lg[d[x]]-1; k>=0; --k){
            if(fa[x][k]!=fa[y][k]){
                x=fa[x][k], y=fa[y][k];
            }
        }
        return fa[x][0];
    }
    int dis(int a, int b){
        return d[a]+d[b]-2*d[lca(a, b)];
    }
    int getk(int f, int k){
        for(int i=20; i>=0; i--){
            if(k&(1<<i)){
                f=fa[f][i];
            }
        }
        return f;
    }
}lca;

int main(){
    int n, m, x, y;
    scanf("%d", &n);
    for(int i=2; i<=n; i++){
        scanf("%d%d", &x, &y);
        G[x].push_back(y);
        G[y].push_back(x);
    }
    lca.init(n, 1);
    scanf("%d", &m);
    while(m--){
        scanf("%d%d", &x, &y);
        int dis=lca.dis(x, y);
        if(x==y){
            printf("%d\n", n);
            continue;
        }
        if(dis%2){
            printf("%d\n", 0);
        }
        else{
            dis/=2;
            int f, ans=0;
            if(lca.d[x]<lca.d[y]){
                f=lca.getk(y, dis-1);
                ans=lca.siz[lca.fa[f][0]]-lca.siz[f];
            }
            else if(lca.d[x]>lca.d[y]){
                f=lca.getk(x, dis-1);
                ans=lca.siz[lca.fa[f][0]]-lca.siz[f];
            }
            else{
                int fa=lca.getk(x, dis-1);
                int fb=lca.getk(y, dis-1);
                ans=n-lca.siz[fa]-lca.siz[fb];
            }
            printf("%d\n", ans);
        }
    }

    return 0;
}