1. 程式人生 > 洛谷 P3233 [HNOI2014]世界樹(虛樹+dp)

洛谷 P3233 [HNOI2014]世界樹(虛樹+dp)

題面

luogu

題解

資料範圍已經提示我們要用虛樹了，接下來考慮如何在虛樹上做 dp。

以下摘自hzwer部落格:

構建虛樹以後，用兩遍 dp（先自底向上、再自頂向下）求出虛樹上每個點最近的議事處；距離相同時取編號較小者。

然後列舉虛樹上每一條邊,考慮其對兩端點的答案貢獻

可以用倍增二分出分界點

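設 a 是 b 在虛樹上的父親，a 在原樹 a→b 路徑上的第一個兒子為 x；分界點 mid 是該路徑上（不含 a）深度最小、仍歸 b 的最近議事處管轄的點

則對 a 的最近議事處（bel[a]）的貢獻是 size[x]-size[mid]

對 b 的最近議事處（bel[b]）的貢獻是 size[mid]-size[b]；若 bel[a] 與 bel[b] 相同，則整段 size[x]-size[b] 都歸它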

最後還要算上沒被任何一條虛樹邊考慮到的點：對虛樹上每個點 a，size[a] 減去它每條虛樹邊對應的 size[x] 後剩下的部分（程式中的 rem[a]），全部歸 a 的最近議事處。

Code

// luogu-judger-enable-o2
#include<bits/stdc++.h>

#define LL long long
#define RG register

using namespace std;
// Reads one (optionally negative) decimal integer from stdin into x.
// Any non-digit characters before the number are skipped. If the input ends
// before a digit is seen, x is left as 0 instead of spinning forever.
template<class T> inline void read(T &x) {
    x = 0;
    // int, not char: getchar() returns EOF, which must stay distinct from every
    // byte value (storing it in a char made the skip loop below never terminate).
    int c = getchar();
    bool neg = false;
    while (c != EOF && c != '-' && (c < '0' || c > '9')) c = getchar();
    if (c == '-') {
        neg = true;
        c = getchar();
    }
    while (c >= '0' && c <= '9') {
        x = x * 10 + (c - '0');
        c = getchar();
    }
    if (neg) x = -x;
}
// Writes integer x to stdout in decimal (no separator, no newline).
// Negative values are printed digit by digit from their (non-positive)
// remainders, so the type's minimum value (e.g. LLONG_MIN) is handled without
// the signed overflow that negating it would cause.
template<class T> inline void write(T x) {
    if (x == 0) {
        putchar('0');
        return;
    }
    const bool neg = x < 0;
    if (neg) putchar('-');
    char digits[40];   // enough for any built-in integer type, including 128-bit
    int cnt = 0;
    while (x != 0) {
        const T r = x % 10;   // in [-9, 0] when x is negative (truncating division)
        digits[cnt++] = char('0' + (neg ? -r : r));
        x /= 10;
    }
    while (cnt > 0) putchar(digits[--cnt]);
}
int n;                    // number of vertices in the original tree
const int N = 300010;     // capacity of vertex-indexed arrays
// Adjacency-list edge: target vertex and index of the next edge with the same
// source (0 terminates the list).
struct node {
    int to, next;
}g[N<<1];                 // shared edge pool; valid indices start at 1
int last[N], gl;          // last[v] = most recently added edge out of v; gl = pool size
// Appends the directed edge x -> y to x's adjacency list.
inline void add(int x, int y) {
    // Standard C++ aggregate initialisation; the C99 compound literal
    // `(node){...}` used previously is only a GNU extension in C++.
    g[++gl] = node{y, last[x]};
    last[x] = gl;
}
// dfn: pre-order index, cnt: pre-order counter, siz: subtree size, dep: depth,
// anc[v][k]: 2^k-th ancestor (0 above the root), rem/bel: per-query DP values.
int dfn[N], cnt, siz[N], dep[N], anc[N][21], rem[N], bel[N];
// Depth-first traversal of the original tree from u (whose parent is fa):
// assigns pre-order numbers, subtree sizes, child depths and the
// binary-lifting table. Uses an explicit stack instead of recursion, because a
// path-shaped tree would otherwise need one call frame per vertex and can
// overflow the call stack; vertices are visited in the same order as before.
void init(int u, int fa) {
    std::vector<int> node_stk, par_stk, edge_stk;   // vertex, its parent, next edge to try
    // Pre-order work for vertex x with parent p, then push it on the stack.
    auto enter = [&](int x, int p) {
        dfn[x] = ++cnt;
        siz[x] = 1;
        anc[x][0] = p;
        for (int i = 1; i <= 20; i++)
            anc[x][i] = anc[anc[x][i-1]][i-1];
        node_stk.push_back(x);
        par_stk.push_back(p);
        edge_stk.push_back(last[x]);
    };
    enter(u, fa);
    while (!node_stk.empty()) {
        const int x = node_stk.back(), e = edge_stk.back();
        if (e) {
            edge_stk.back() = g[e].next;       // advance before descending
            const int v = g[e].to;
            if (v == par_stk.back()) continue; // do not walk back to the parent
            dep[v] = dep[x] + 1;
            enter(v, x);
        } else {
            // x is finished: fold its subtree size into its parent.
            node_stk.pop_back();
            par_stk.pop_back();
            edge_stk.pop_back();
            if (!node_stk.empty()) siz[node_stk.back()] += siz[x];
        }
    }
}
// Lowest common ancestor of x and y via the binary-lifting table anc[][]
// (anc[v][k] is the 2^k-th ancestor of v, filled in by init; vertex 0 sits
// above the root because init(1, 0) makes it the root's parent).
int lca(int x, int y) {
    if (dep[x] < dep[y]) swap(x, y);   // x is now the deeper endpoint
    // Raise x to y's depth, decomposing the depth gap into powers of two.
    for (int k = 20; k >= 0; k--) {
        if (dep[x] - dep[y] >= (1 << k)) x = anc[x][k];
    }
    if (x == y) return x;              // y was an ancestor of x
    // Raise both while their 2^k-th ancestors differ; they stop right below the LCA.
    for (int k = 20; k >= 0; k--) {
        const int px = anc[x][k];
        const int py = anc[y][k];
        if (px != py) {
            x = px;
            y = py;
        }
    }
    return anc[x][0];
}
// Number of edges on the original-tree path between x and y.
int dis(int x, int y) {
    const int w = lca(x, y);
    return (dep[x] - dep[w]) + (dep[y] - dep[w]);
}
int top, len, m;                    // stack height, size of c[], key-vertex count of the current query
int a[N], b[N], s[N], c[N], f[N];   // keys (dfn-sorted), keys (input order), monotone stack, virtual-tree vertices, answers
// Orders two vertices by their DFS entry time dfn[].
bool cmp(int a, int b) {
    return dfn[b] > dfn[a];
}

// Inserts key vertex x into the virtual tree under construction. Keys must
// arrive in increasing dfn order. s[1..top] holds one root-to-leaf chain of the
// virtual tree; an edge is emitted with add() as soon as a vertex is popped
// off that chain.
inline void insert(int x) {
    // Empty stack (vertex 1 is itself the first key) or only the root on the
    // stack: x can be pushed directly. Handling top == 0 here avoids calling
    // lca() against the never-assigned sentinel slot s[0].
    if (top <= 1) { s[++top] = x; return; }
    int o = lca(x, s[top]);
    // While the second-from-top vertex is o or a descendant of o, the top vertex
    // gets no more descendants: link it to its chain predecessor and pop it.
    while (top > 1 && dfn[s[top-1]] >= dfn[o]) add(s[top-1], s[top]), top--;
    // o branches off strictly below the remaining top: link o to it and let o
    // take its place on the chain.
    if (o != s[top]) add(o, s[top]), s[top] = o;
    s[++top] = x;
}

// First (bottom-up) pass over the virtual tree rooted at x.
// - Appends every virtual-tree vertex to c[1..len] in pre-order.
// - Initialises rem[v] = siz[v].
// - For each vertex, keeps in bel[] the nearest key vertex found in its virtual
//   subtree (ties go to the smaller id; key vertices already have bel[v] = v).
// Uses an explicit stack: virtual-tree depth can reach the number of key
// vertices, which is too deep for recursion. Each child is merged into its
// parent right after the child's subtree finishes, as in the recursive
// formulation.
void dfs1(int x) {
    std::vector<int> node_stk, edge_stk;   // vertex, next unvisited virtual-tree edge
    c[++len] = x; rem[x] = siz[x];
    node_stk.push_back(x);
    edge_stk.push_back(last[x]);
    while (!node_stk.empty()) {
        const int u = node_stk.back(), e = edge_stk.back();
        if (e) {
            edge_stk.back() = g[e].next;
            const int v = g[e].to;
            c[++len] = v; rem[v] = siz[v];
            node_stk.push_back(v);
            edge_stk.push_back(last[v]);
            continue;
        }
        // u's subtree is complete; offer its candidate to the parent.
        node_stk.pop_back();
        edge_stk.pop_back();
        if (node_stk.empty() || !bel[u]) continue;
        const int p = node_stk.back();
        int t1 = dis(bel[u], p), t2 = dis(bel[p], p);
        if ((t1 == t2 && bel[u] < bel[p]) || t1 < t2 || !bel[p])
            bel[p] = bel[u];
    }
}
// Second (top-down) pass over the virtual tree rooted at x: a child adopts its
// parent's nearest key vertex when that one is strictly closer, when it is
// equally close with a smaller id, or when the child has none yet. Uses an
// explicit stack (same visiting order as the recursive form) so that deep
// virtual trees cannot overflow the call stack.
void dfs2(int x) {
    std::vector<int> node_stk(1, x), edge_stk(1, last[x]);   // vertex, next edge to try
    while (!node_stk.empty()) {
        const int e = edge_stk.back();
        if (!e) {
            node_stk.pop_back();
            edge_stk.pop_back();
            continue;
        }
        const int u = node_stk.back();
        edge_stk.back() = g[e].next;
        const int v = g[e].to;
        int t1 = dis(bel[u], v), t2 = dis(bel[v], v);
        if ((t1 == t2 && bel[v] > bel[u]) || t1 < t2 || !bel[v])
            bel[v] = bel[u];
        node_stk.push_back(v);
        edge_stk.push_back(last[v]);
    }
}

// Distributes the original-tree vertices hanging off virtual edge a -> b
// (a is b's virtual parent) between the nearest key vertices of a and b.
// Also removes the whole branch towards b from rem[a].
void solve(int a, int b) {
    // Lift b to the child of a that lies on the a-b path.
    int branch = b;
    for (int k = 20; k >= 0; k--) {
        const int up = anc[branch][k];
        if (dep[up] > dep[a]) branch = up;
    }
    rem[a] -= siz[branch];
    const int ka = bel[a];
    const int kb = bel[b];
    if (ka == kb) {
        // Both ends share a nearest key vertex: it takes the whole strip.
        f[ka] += siz[branch] - siz[b];
        return;
    }
    // Highest vertex strictly below a on the path that still belongs to kb.
    int split = b;
    for (int k = 20; k >= 0; k--) {
        const int up = anc[split][k];
        if (dep[up] <= dep[a]) continue;
        const int da = dis(ka, up);
        const int db = dis(kb, up);
        if (db < da || (db == da && kb < ka)) split = up;
    }
    f[ka] += siz[branch] - siz[split];
    f[kb] += siz[split] - siz[b];
}

// Answers one query: reads m key vertices and builds their virtual tree rooted
// at vertex 1. It then attributes every original vertex to its nearest key
// vertex (ties -> smaller id) and prints, in input order, each key vertex's
// count. Per-query arrays touched here are cleared before returning.
void query() {
    gl = 0;    // virtual-tree edges reuse the edge pool from index 1
    len = 0;   // number of virtual-tree vertices recorded in c[]
    top = 0;   // monotone stack height
    read(m);
    for (int i = 1; i <= m; i++) {
        read(a[i]);
        b[i] = a[i];        // a[] gets sorted; b[] keeps the input order for output
        bel[a[i]] = a[i];   // a key vertex is its own nearest key vertex
    }
    sort(a + 1, a + 1 + m, cmp);
    // Vertex 1 is always the virtual-tree root; seed the stack with it unless it is a key.
    if (bel[1] != 1) s[++top] = 1;
    for (int i = 1; i <= m; i++) insert(a[i]);
    // What remains on the stack is a single chain; link neighbouring entries.
    for (int i = 2; i <= top; i++) add(s[i - 1], s[i]);
    dfs1(1);
    dfs2(1);
    for (int i = 1; i <= len; i++) {
        const int u = c[i];
        for (int e = last[u]; e; e = g[e].next) solve(u, g[e].to);
        // Only solve(u, ...) changes rem[u], so it is final at this point.
        f[bel[u]] += rem[u];
    }
    for (int i = 1; i <= m; i++) {
        write(f[b[i]]);
        putchar(' ');
    }
    putchar('\n');
    for (int i = 1; i <= len; i++) {
        const int u = c[i];
        last[u] = 0;
        bel[u] = 0;
        f[u] = 0;
    }
}

// Reads the tree (n vertices, n - 1 undirected edges), preprocesses it rooted
// at vertex 1, then answers q queries.
int main() {
    read(n);
    for (int e = 1; e < n; e++) {
        int u, v;
        read(u);
        read(v);
        add(u, v);   // store the undirected edge in both directions
        add(v, u);
    }
    init(1, 0);      // root at vertex 1; vertex 0 acts as the root's parent
    // Drop the original adjacency heads so the edge pool can hold virtual trees
    // (query() restarts the pool counter gl at 0).
    memset(last, 0, sizeof(last));
    int q;
    read(q);
    for (; q; q--) query();
    return 0;
}