1. 程式人生 > 實用技巧 >2020暑假牛客多校9 B - Groundhog and Apple Tree (樹形dp)

2020暑假牛客多校9 B - Groundhog and Apple Tree (樹形dp)

2020暑假牛客多校9 B - Groundhog and Apple Tree (樹形dp)

補題連結 參考部落格

題目大意:

給一個樹,走每條邊會減hp, 走到點會加hp,原地等待也會加hp, 問最少原地等待時間使得能夠遍歷所有點。每條邊最多走兩次。

題解:

首先每條邊最多走兩次那也就dfs一遍樹的過程,既然所有點都要走,那關鍵也就在與每次先走哪個子樹,即去考慮遍歷子樹的順序。

time 表示子樹需要最小的等待時間,hp 表示遍歷子樹能夠得到的hp。那麼對於一個子樹有以下幾種情況

​ (1) hp > time : 這時候把子樹遍歷一遍hp會增加,那肯定先處理這一類子樹

​ (2) hp < time: hp會減少 ,hp無法滿足遍歷所需等待時間

  • 所以可以得到排序規則

先處理hp > time的子樹,再處理hp < time的子樹,對於hp > time的子樹再按time從小到大排序,因為大家都能增加hp那我為什麼不把需要time大的子樹往後放,等前面time小的子樹處理完,hp多增加一些後再處理time大的子樹。這樣肯定更優。對於hp < time的子樹, 那麼我可以把這些子樹按hp從大到小排序。

程式碼:

#include<bits/stdc++.h>
using namespace std;
#define rep(i, a, n) for(int i = a; i <= n; ++ i);
#define per(i, a, n) for(int i = n; i >= a; -- i);
typedef long long ll;
const int N = 2e6+ 5;
const int mod = 998244353;
const double Pi = acos(- 1.0);
const int INF = 0x3f3f3f3f;
const int G = 3, Gi = 332748118;
ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; }
ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; }
ll lcm(ll a, ll b) { return a * b / gcd(a, b);}
bool cmp(int a, int b){ return a > b;}
//

int T, n;
ll val[N];
struct node1{
    ll hp, time;
}dp[N];

int head[N], cnt = 0;
struct node{
    int to, nxt;ll c;
}edge[N << 1];

void add(int u, int v, ll w){
    edge[cnt].to = v, edge[cnt].c = w, edge[cnt].nxt = head[u], head[u] = cnt ++;
    edge[cnt].to = u, edge[cnt].c = w, edge[cnt].nxt = head[v], head[v] = cnt ++;
}


bool cmp1(node1 a, node1 b){
    if((a.hp > a.time) ^ (b.hp > b.time)) return (a.hp > a.time);
    if(a.hp > a.time) return a.time < b.time;
    return a.hp > b.hp;
    
    
    // if(a.hp > a.time){
        // if(b.hp < b.time) return true;
        // else return a.time < b.time;
    // } 
    // else{
        // if(b.time > b.hp) return a.hp > b.hp;
        // else return false;
    // }
}

void dfs(int u, int pre){
    vector<node1> sol;
    for(int i = head[u]; i != -1; i = edge[i].nxt){
        int v = edge[i].to; ll w = edge[i].c;
        if(v == pre) continue;
        
        dfs(v, u);
        
        if(w >= dp[v].hp) dp[v].time += 2 * w - dp[v].hp, dp[v].hp = 0;
        else dp[v].time += w, dp[v].hp -= w;
        sol.push_back(dp[v]);
    }
    
    sort(sol.begin(), sol.end(), cmp1);
    
    int tt = sol.size();
    ll minn = val[u], thp = val[u];
    for(int i = 0; i < tt; ++ i){
        ll hp = sol[i].hp, time = sol[i].time;
        minn = min(minn, thp - time);
        thp += hp - time;
    }
    if(minn >= 0) dp[u].time = 0, dp[u].hp = thp;
    else dp[u].time -= minn, dp[u].hp = thp - minn;
}


int main()
{
    scanf("%d",&T);
    while(T --){
        scanf("%d",&n);
        cnt = 0;
        for(int i = 1; i <= n; ++ i) {
            dp[i].time = dp[i].hp = 0;
            head[i] = -1;
            scanf("%lld",&val[i]);
        }
        for(int i = 1; i < n; ++ i){
            int x, y; ll z; scanf("%d%d%lld",&x,&y,&z);
            add(x, y, z);
        }
        dfs(1, 0);
        printf("%lld\n",dp[1].time);
    }
    return 0;
}