1. 程式人生 > 實用技巧 >MMSet2-【樹的直徑+LCA】

MMSet2-【樹的直徑+LCA】

題意

給定一棵 \(n\) 個節點的樹,點編號為 \(1...n\)\(Q\) 次詢問,每次詢問給定一個點集 \(S\),令 \(f(u)=\max\limits_{v\in S}dist(u,v)\) ,你需要求出\(\min\limits_{u=1...n}f(u)\)。其中 \(dist(u,v)\) 表示樹上路徑 \((u,v)\) 的邊數。

連結:https://ac.nowcoder.com/acm/contest/7141/A

分析

題意轉化為找出一個點,使得其到點集中的點的距離的最大值最小。對於點集組成的樹而言,該點肯定在其直徑的中點左右,那麼答案就是 \(\lceil \frac{直徑}{2} \rceil\)

。因此,需要求出直徑的長度。因為點集中深度最深的點一定是直徑的一個端點,因此,只要找到該點,然後遍歷其它的點,通過 \(LCA\) 求出兩點間距離,取最大值即可。

\(LCA\) 通過 \(ST\) 表實現,這樣可以實現 \(O(1)\) 查詢,用倍增可能會超時。

程式碼

#include <bits/stdc++.h>
#define pb push_back
using namespace std;
const int N=3e5+5;
const int M=1e6+6;
const int mak=25;
vector<int>pic[N];
int depth[N],f[N<<1][mak],rnk[N],n,cnt;
int s[M];
void read(int &x)
{
    x=0;
    int f=1;
    char ch=getchar();
    while(ch<'0'||ch>'9')
    {
        if(ch=='-') f=-1;
        ch=getchar();
    }
    while(ch>='0'&&ch<='9')
    {
        x=x*10+ch-'0';
        ch=getchar();
    }
    x*=f;
}
void dfs(int u,int p,int d)
{
    depth[u]=d;
    rnk[u]=++cnt;
    f[cnt][0]=u;
    for(int i=0;i<pic[u].size();i++)
    {
        int v=pic[u][i];
        if(v==p) continue;
        dfs(v,u,d+1);
        f[++cnt][0]=u;
    }
}
void init()
{
    dfs(1,0,0);
    int len=(int)log2(1.0*cnt);
    for(int k=1;k<=len;k++)
    {
        for(int i=1;i+(1<<k)-1<=cnt;i++)
        {
            int a=f[i][k-1],b=f[i+(1<<(k-1))][k-1];
            if(depth[a]<depth[b]) f[i][k]=a;
            else f[i][k]=b;
        }
    }
}
int lca(int u,int v)
{
    int l=rnk[u],r=rnk[v];
    if(l>r) swap(l,r);
    int len=(int)log2(1.0*(r-l+1));
    int a=f[l][len],b=f[r-(1<<len)+1][len];
    if(depth[a]<depth[b]) return a;
    else return b;
}
int get_dis(int u,int v)
{
    return depth[u]+depth[v]-2*depth[lca(u,v)];
}
int main()
{
    int x,y,m,q;
    read(n);
    for(int i=1;i<n;i++)
    {
        read(x),read(y);
        pic[x].pb(y);
        pic[y].pb(x);
    }
    init();
    read(q);
    while(q--)
    {
        read(m);
        int ans=0,dm=0,d=0;
        for(int i=1;i<=m;i++)
        {
            read(s[i]);
            if(depth[s[i]]>dm)
                dm=depth[s[i]],d=s[i];
        }
        for(int i=1;i<=m;i++)
        {
            if(s[i]==d) continue;
            ans=max(ans,get_dis(d,s[i]));
        }
        printf("%d\n",(ans+1)/2);
    }
    return 0;
}