2020暑假牛客多校9 B - Groundhog and Apple Tree (樹形dp)
阿新 • • 發佈:2020-08-15
2020暑假牛客多校9 B - Groundhog and Apple Tree (樹形dp)
題目大意:
給一個樹,走每條邊會減hp, 走到點會加hp,原地等待也會加hp, 問最少原地等待時間使得能夠遍歷所有點。每條邊最多走兩次。
題解:
首先每條邊最多走兩次那也就dfs一遍樹的過程,既然所有點都要走,那關鍵也就在與每次先走哪個子樹,即去考慮遍歷子樹的順序。
用time
表示子樹需要最小的等待時間,hp
表示遍歷子樹能夠得到的hp。那麼對於一個子樹有以下幾種情況:
(1) hp > time : 這時候把子樹遍歷一遍hp會增加,那肯定先處理這一類子樹
(2) hp < time: hp會減少 ,hp無法滿足遍歷所需等待時間
- 所以可以得到排序規則:
先處理hp > time的子樹,再處理hp < time的子樹,對於hp > time的子樹再按time從小到大排序,因為大家都能增加hp那我為什麼不把需要time大的子樹往後放,等前面time小的子樹處理完,hp多增加一些後再處理time大的子樹。這樣肯定更優。對於hp < time的子樹, 那麼我可以把這些子樹按hp從大到小排序。
程式碼:
#include<bits/stdc++.h> using namespace std; #define rep(i, a, n) for(int i = a; i <= n; ++ i); #define per(i, a, n) for(int i = n; i >= a; -- i); typedef long long ll; const int N = 2e6+ 5; const int mod = 998244353; const double Pi = acos(- 1.0); const int INF = 0x3f3f3f3f; const int G = 3, Gi = 332748118; ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; } ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; } ll lcm(ll a, ll b) { return a * b / gcd(a, b);} bool cmp(int a, int b){ return a > b;} // int T, n; ll val[N]; struct node1{ ll hp, time; }dp[N]; int head[N], cnt = 0; struct node{ int to, nxt;ll c; }edge[N << 1]; void add(int u, int v, ll w){ edge[cnt].to = v, edge[cnt].c = w, edge[cnt].nxt = head[u], head[u] = cnt ++; edge[cnt].to = u, edge[cnt].c = w, edge[cnt].nxt = head[v], head[v] = cnt ++; } bool cmp1(node1 a, node1 b){ if((a.hp > a.time) ^ (b.hp > b.time)) return (a.hp > a.time); if(a.hp > a.time) return a.time < b.time; return a.hp > b.hp; // if(a.hp > a.time){ // if(b.hp < b.time) return true; // else return a.time < b.time; // } // else{ // if(b.time > b.hp) return a.hp > b.hp; // else return false; // } } void dfs(int u, int pre){ vector<node1> sol; for(int i = head[u]; i != -1; i = edge[i].nxt){ int v = edge[i].to; ll w = edge[i].c; if(v == pre) continue; dfs(v, u); if(w >= dp[v].hp) dp[v].time += 2 * w - dp[v].hp, dp[v].hp = 0; else dp[v].time += w, dp[v].hp -= w; sol.push_back(dp[v]); } sort(sol.begin(), sol.end(), cmp1); int tt = sol.size(); ll minn = val[u], thp = val[u]; for(int i = 0; i < tt; ++ i){ ll hp = sol[i].hp, time = sol[i].time; minn = min(minn, thp - time); thp += hp - time; } if(minn >= 0) dp[u].time = 0, dp[u].hp = thp; else dp[u].time -= minn, dp[u].hp = thp - minn; } int main() { scanf("%d",&T); while(T --){ scanf("%d",&n); cnt = 0; for(int i = 1; i <= n; ++ i) { dp[i].time = dp[i].hp = 0; head[i] = -1; scanf("%lld",&val[i]); } for(int i = 1; i < n; ++ i){ int x, y; ll z; scanf("%d%d%lld",&x,&y,&z); add(x, y, z); } dfs(1, 0); printf("%lld\n",dp[1].time); } return 0; }