Bzoj 2286 & Luogu P2495 消耗戰(LCA+虛樹+尤拉序)
阿新 • • 發佈:2018-12-24
題面
題解
很容易想到$O(nk)$的樹形$dp$吧,設$f[i]$表示處理完這$i$顆子樹的最小花費,同時再設一個$mi[i]$表示$i$到根節點$1$路徑上的距離最小值。於是有:
$ f[i]=\sum min(f[son[i]], mi[son[i]]) $
這樣就有$40$分了。
考慮優化:這裡可以用虛樹來優化,先把所有點按照$DFS$序進行排序,然後將相鄰兩個點的$LCA$以及$1$號點加入進$LCA$,然後虛樹就構好了,考慮尤拉序的特殊性質,所以再還原出尤拉序,在上面做$dp$就好了。(xgzc告訴我可以再$dfs$一遍,但我不想寫了,尤拉序多好啊)
#include <cmath> #include <cstdio> #include <cstring> #include <algorithm> using std::min; using std::max; using std::swap; using std::sort; typedef long long ll; template<typename T> void read(T &x) { int flag = 1; x = 0; char ch = getchar(); while(ch < '0' || ch > '9') { if(ch == '-') flag = -flag; ch = getchar(); } while(ch >= '0' && ch <= '9') x = x * 10 + ch - '0', ch = getchar(); x *= flag; } const int N = 2.5e5 + 10, Inf = 1e9 + 7; int n, m, dfin[N], dfout[N], tim; int cnt, from[N], to[N << 1], nxt[N << 1]; int siz[N], son[N], dep[N], top[N], fa[N]; int nt[N << 1], vis[N], s[N << 1], tt; ll mi[N], f[N], dis[N << 1]; inline void addEdge(int u, int v, ll w) { to[++cnt] = v, dis[cnt] = w, nxt[cnt] = from[u], from[u] = cnt; } inline bool cmp(const int &x, const int &y) { int k1 = x > 0 ? dfin[x] : dfout[-x], k2 = y > 0 ? dfin[y] : dfout[-y]; return k1 < k2; } void dfs(int u) { siz[u] = 1, dfin[u] = ++tim, dep[u] = dep[fa[u]] + 1; for(int i = from[u]; i; i = nxt[i]) { int v = to[i]; if(v == fa[u]) continue; mi[v] = min(mi[u], dis[i]); fa[v] = u, dfs(v), siz[u] += siz[v]; if(siz[v] > siz[son[u]]) son[u] = v; } dfout[u] = ++tim; } void dfs(int u, int t) { top[u] = t; if(!son[u]) return ; dfs(son[u], t); for(int i = from[u]; i; i = nxt[i]) if(to[i] != son[u] && to[i] != fa[u]) dfs(to[i], to[i]); } int lca(int x, int y) { int fx = top[x], fy = top[y]; while(fx != fy) if(dep[fx] > dep[fy]) x = fa[fx], fx = top[x]; else y = fa[fy], fy = top[y]; return dep[x] < dep[y] ? x : y; } int main () { read(n); for(int i = 1; i < n; ++i) { int u, v; ll w; read(u), read(v), read(w); addEdge(u, v, w), addEdge(v, u, w); } mi[1] = Inf, dfs(1), dfs(1, 1), read(m); for(int i = 1; i <= m; ++i) { int tot; read(tot); for(int j = 1; j <= tot; ++j) read(nt[j]), vis[nt[j]] = true, f[nt[j]] = mi[nt[j]]; sort(&nt[1], &nt[tot + 1], cmp); for(int j = 1; j < tot; ++j) { int cf = lca(nt[j], nt[j + 1]); if(!vis[cf]) nt[++tot] = cf, vis[cf] = true; } int tmp = tot; for(int j = 1; j <= tmp; ++j) nt[++tot] = -nt[j]; if(!vis[1]) nt[++tot] = 1, nt[++tot] = -1; sort(&nt[1], &nt[tot + 1], cmp); for(int j = 1; j <= tot; ++j) if(nt[j] > 0) s[++tt] = nt[j]; else { int now = s[tt--]; if(now != 1) { int fat = s[tt]; f[fat] += min(f[now], mi[now]); } else printf("%lld\n", f[1]); f[now] = vis[now] = 0; } } return 0; }