1. 程式人生 > 實用技巧 >每日一題 8月18日MMSet2 求樹上集合點的最大距離最小

每日一題 8月18日MMSet2 求樹上集合點的最大距離最小

題目

MMSet2(https://ac.nowcoder.com/acm/problem/14250)

題目描述

給定一棵n個節點的樹,點編號為1…n。
Q次詢問,每次詢問給定一個點集S,

其中dist(u,v)表示樹上路徑(u,v)的邊數。

輸入描述:

第一行一個整數n,接下來n−1行每行兩個整數表示樹上的一條邊。
接下來一行一個整數Q,接著Q行,每行第一個數是|S|,剩下|S|個互不相同的數代表這個集合。

輸出描述:

輸出Q行,每行一個整數表示答案。

輸入

3
1 2
1 3
1
2 2 3

輸出

1

備註:

n≤3×105,|S|≥1,∑|S|≤106

思路:

題目求樹上任意選擇一點到|S|集合的點的最大距離最小。這個點應該是|S|中距離最遠點的點對的中點。
假設為x,y。ans=(dis(x, y)+1)/2
現在怎麼求|S|的直徑。
先以1為根維護出每個點的深度,那麼S集合裡面深度最深的一個點就是最長距離一個端點。
然後在|S|列舉另外一個點就可以了。這個用樹上LCA就可以了。

#pragma GCC optimize(3, "Ofast", "inline")
#pragma GCC target("avx,avx2,fma")
#pragma GCC optimization ("unroll-loops")

#include <bits/stdc++.h>
using namespace std;

#define LL long long
#define rint register int

#define getchar() (p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<20,stdin),p1==p2)?EOF:*p1++)
char buf[1<<20],*p1=buf,*p2=buf;
inline int read() {
    int f=0,fu=1;
    char c=getchar();
    while(c<'0'||c>'9') {
        if(c=='-')
            fu=-1;
        c=getchar();
    }
    while(c>='0'&&c<='9') {
        f=(f<<3)+(f<<1)+c-48;
        c=getchar();
    }
    return f*fu;
}

inline void print(LL x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    if (x >= 10)
        print(x / 10);
    putchar(x % 10 + '0');
}

const int N = 300015;

struct Edge {
    int to, nxt, val;
} edge[N << 1];


struct RMQ_lca {

    LL w[N];
    int head[N], etot;
    int lg[N << 1], a[N << 1], dfn[N], dep[N], tot;
    int f[N << 1][20];

    void init(int n) {
        etot=tot=0;
        memset(head, 0, sizeof(head[0])*(n+5));
        memset(w, 0, sizeof(w[0])*(n+5));
    }

    void add(int u, int v, int w) {
        edge[++etot] = {v, head[u], w};
        head[u] = etot;
    }

    void dfs(int u, int fa) {
        f[++tot][0] = u, dfn[u] = tot, dep[u] = dep[fa] + 1;
        for(rint i = head[u]; i; i = edge[i].nxt) {
            int v = edge[i].to;
            if (v == fa)
                continue;
            w[v] = w[u] + edge[i].val;
            dfs(v, u);
            f[++tot][0] = u;
        }
    }
    void pre() {
        lg[1] = 0;
        for (rint i = 2; i <= tot; i++)
            lg[i] = lg[i >> 1] + 1;
        for (rint j = 1; j <= 19; j++) {
            for (rint i = 1; i + (1 << j) - 1 <= tot; i++) {
                if (dep[f[i][j - 1]] < dep[f[i + (1 << j - 1)][j - 1]]) {
                    f[i][j] = f[i][j - 1];
                } else {
                    f[i][j] = f[i + (1 << j - 1)][j - 1];
                }
            }
        }
    }
    int LCA(int u, int v) {
        u = dfn[u], v = dfn[v];
        if (u > v)
            swap(u, v);
        int len = lg[v - u + 1];
        if (dep[f[u][len]] < dep[f[v - (1 << len) + 1][len]]) {
            return f[u][len];
        } else {
            return f[v - (1 << len) + 1][len];
        }
    }
    LL dist(int u, int v) {
        //cout<<u<<" "<<v<<" "<<LCA(u, v)<<endl;
        return w[u] + w[v] - 2ll * w[LCA(u, v)];
    }

} lca;

vector<int> v;
int main() {

    int n=read();
    lca.init(n);
    for(int i=1; i<n; i++) {
        int x, y;
        x=read(), y=read();
        lca.add(x, y, 1);
        lca.add(y, x, 1);
    }
    lca.dfs(1, 0);
    lca.pre();
    int q=read();
    while(q--) {
        int m=read();
        v.clear();
        int mx=0, ans=0, x;
        while(m--) {
            x=read();
            v.push_back(x);
            if(lca.dep[mx]<lca.dep[x]) {
                mx=x;
            }
        }
        for(auto x: v) {
            //cout<<x<<" "<<mx<<" "<<((lca.dist(x, mx)+1)/2)<<endl;
            ans=max(1ll*ans, (lca.dist(x, mx)+1)/2);
        }
        print(ans); putchar('\n');
    }

    return 0;
}