[APIO2014]連珠線
注意到藍線的產生方法,可以發現藍線必然可以劃分成若干個三個點中間由藍線連成的情況,在樹上就只有下面兩種形式:
(從 \(\rm nofind\) 大佬哪裡擓來的圖)
那麼是不是我們只要吧原樹劃分成若干個互不相交的長度為 \(3\) 的鏈然後把這些點之間的邊染成藍色這樣的樹都是合法的呢?其實不是,比如下面這個情況:
按照上面的理論,我們可以選擇左下角和右下角的三個點作為藍線所在的鏈,但實際上這是不可能的,因為我們只能從一個珠子開始,所以如果要選擇左下角那三個點,那麼必須一步步地拉紅線過去,而最終一定會先拉到 \(3\) 號點,這樣就不能先連下面兩個點再讓 \(3\) 成為中間的的連線點。於是我們可以發現一個劃分藍邊的方案是合法的當且僅當向前面提到的第一種長度為 \(3\)
可以感覺到這個東西是不能貪心求解最優解的,那麼我們只能考慮 \(dp\) 了,但是如果我們要直接去做的話不僅狀態很難設計,轉移也非常繁雜。為了簡化狀態,一個常見的想法就是看能否將上面的兩種劃分方式看成一種,不難發現第二種劃分方案是更好來做的,於是我們可以考慮將第一種劃分方案轉化成第二種劃分方案。觀察一下這條鏈,第一種劃分方案實際上是兒子 \(\rightarrow\) 父親 \(\rightarrow\) 兒子,那麼我們只需要在上面紅色箭頭指向的點做為樹的根將樹倒過來,那麼原來的兒子 \(\rightarrow\)
\[dp_{u, 0} = \sum \max\{dp_{v, 0}, dp_{v, 1} + e_i.w\} \]
\[dp_{u, 1} = \max\{dp_{u, 0} - \max\{dp_{v, 0}, dp_{v, 1} + e_i.w\} + dp_{v, 0} + e_i.w\} \]
這樣就可以做到 \(O(n ^ 2)\) 了,當要列舉不同的根時一個常見的套路就是使用換根 \(dp\),讓我們來思考一下從 \(u\) 為根到 \(v\) 為根需要變換那些值。首先這裡換根時影響到的 \(dp\) 值實際上只有 \(u, v\) 兩個節點。\(dp_{u, 0}\) 只需減去 \(\max\{dp_{v, 0}, dp_{v, 1} + e_i.w\}\) 即可。\(dp_{u, 1}\) 的變化比較麻煩,因為可能出現 \(dp_{u, 1}\) 是從 \(v\) 轉移過來的情況,這樣 \(v\) 對答案的影響就可能有兩種,但我們可以提前預處理好 \(f_v\) 表示在以 \(fa_v\) 為根的子樹內出去 \(v\) 的 \(dp_{fa_v, 1}\) 的值,具體地我們只需在最初每次轉移時記錄最大值和次大值即可。但要注意的是每次在換根的時候 \(f_v\) 是會改變的,因此換根是我們還需再次 \(dp\) 一邊,然後使用棧記錄原來的 \(dp\) 值,結束換根後還原即可。對於 \(dp_{v, 0}\) 的更新,只需加上 \(\max\{dp_{u, 0}, dp_{u, 1} + e_i.w\}\) 即可;對於 \(dp_{v, 1}\) 的更新,將原來的樹中的最優點點的答案和 \(u\) 這個點最為最優點時的答案比較即可。
#include<bits/stdc++.h>
using namespace std;
#define rep(i, l, r) for(int i = l; i <= r; ++i)
#define Next(i, u) for(int i = h[u]; i; i = e[i].next)
const int N = 200000 + 5;
const int inf = 2000000000;
struct edge{
int v, next, w;
}e[N << 1];
long long ans, f[N], st[N], dp[N][2];
int n, u, v, w, tot, top, h[N];
int read(){
char c; int x = 0, f = 1;
c = getchar();
while(c > '9' || c < '0'){ if(c == '-') f = -1; c = getchar();}
while(c >= '0' && c <= '9') x = x * 10 + c - '0', c = getchar();
return x * f;
}
void add(int u, int v, int w){
e[++tot].v = v, e[tot].w = w, e[tot].next = h[u], h[u] = tot;
e[++tot].v = u, e[tot].w = w, e[tot].next = h[v], h[v] = tot;
}
void dfs1(int u, int fa){
dp[u][0] = 0;
Next(i, u){
int v = e[i].v; if(v == fa) continue;
dfs1(v, u), dp[u][0] += max(dp[v][0], dp[v][1] + e[i].w);
}
long long mx = -inf, se = -inf, tmp, p = 0;
Next(i, u){
int v = e[i].v; if(v == fa) continue;
tmp = dp[u][0] - max(dp[v][0], dp[v][1] + e[i].w) + dp[v][0] + e[i].w;
if(tmp > mx) se = mx, mx = tmp, p = v;
else if(tmp > se) se = tmp;
}
dp[u][1] = mx;
Next(i, u){
int v = e[i].v; if(v == fa) continue;
f[v] = (p == v ? se : mx);
}
}
void dfs2(int u, int fa){
ans = max(ans, dp[u][0]);
Next(i, u){
int v = e[i].v; if(v == fa) continue;
long long fdp0 = dp[u][0], fdp1 = dp[u][1], sdp0 = dp[v][0], sdp1 = dp[v][1];
dp[u][0] -= max(dp[v][0], dp[v][1] + e[i].w);
dp[u][1] = f[v] - max(dp[v][0], dp[v][1] + e[i].w);
dp[v][0] += max(dp[u][0], dp[u][1] + e[i].w);
dp[v][1] = max(dp[v][1] + max(dp[u][0], dp[u][1] + e[i].w), sdp0 + dp[u][0] + e[i].w);
int fir = top;
Next(j, v){
int nxt = e[j].v; if(nxt == v) continue;
st[++top] = f[nxt], f[nxt] = max(f[nxt] + max(dp[u][0], dp[u][1] + e[i].w), dp[u][0] + e[i].w + sdp0);
}
dfs2(v, u), dp[u][0] = fdp0, dp[u][1] = fdp1, dp[v][0] = sdp0, dp[v][1] = sdp1;
reverse(st + fir + 1, st + top + 1);
Next(j, v) if(e[j].v != v) f[e[j].v] = st[top--];
}
}
int main(){
n = read();
rep(i, 1, n - 1) u = read(), v = read(), w = read(), add(u, v, w);
dfs1(1, 0), dfs2(1, 0);
printf("%lld", ans);
return 0;
}