hdu7024 Penguin Love Tour(2021杭電暑假多校5)樹形dp
阿新 • • 發佈:2021-08-04
題意
給定一棵\(n\)個點的樹,樹的每個邊有個權值\(w\),每個點有個權值\(p\)。每個點可以把相鄰的某一條邊邊權減\(p\)。最小化直徑。(\(1\le n,w\le{10}^5,0\le p\le{10}^5\))
思路
考慮二分答案,設為\(limit\)。那麼\(check\)就是每棵子樹最大的兩條邊之和不能超過\(limit\)。設\(dp[u][0]\)為節點\(u\)這棵子樹沒有使用\(u\)時,某個葉子到\(u\)的最長路徑的最小值。\(dp[u][1]\)為已經使用了\(u\)的最小值。那麼有:
$ dp[u][0]=max_{v}{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})} \tag{1}$
\(dp[u][1]=min_{v_0}\{max(min(dp[v_0][0]+max(0,w_{u,v_0}-p[v_0]-p[u]),dp[v_0][1]+max(0,w_{u,v_0}-p[u])),\\{max_{v\not=v_0}\{min(dp[v][0]+max(0,w_{u,v}-p[v]),dp[v][1]+w_{u,v})\})}\} \tag{2}\)
然後又因為對兒子用了\(p[u]\)後最長的兒子一定會在\((1)\)中最長的三個中取,那麼求\(dp[u][1]\)只需要列舉\(v_0\)為\((1)\)中最大的三個即可。
程式碼
#include <bits/stdc++.h> using namespace std; using ll=long long; using pii=pair<int,int>; using pli=pair<ll,int>; constexpr ll inf=1e18; inline char gc() { static constexpr int BufferSize = 1 << 22 | 5; static char buf[BufferSize], *p, *q; static std::streambuf *i = std::cin.rdbuf(); return p == q ? p = buf, q = p + i->sgetn(p, BufferSize), p == q ? EOF : *p++ : *p++; } struct Reader { template <class T> Reader &operator>>(T &w) { char c, p = 0; for (; !std::isdigit(c = gc());) if (c == '-') p = 1; for (w = c & 15; std::isdigit(c = gc()); w = w * 10 + (c & 15)) ; if (p) w = -w; return *this; } } fin; template<int N> struct Max{ int n=0; array<pli,N> a; void insert(pli x) { if(n!=0) for(int i=0;i<n;i++) { if(a[i]<x) swap(a[i],x); } if(n<N) a[n++]=x; } void erase(int id) { for(int i=0;i<n;i++) { if(a[i].second==id) { for(int j=i;j<n-1;j++) a[j]=a[j+1]; n--; break; } } } ll sum(int cnt) { cnt=min(cnt,N); ll ans=0; for(int i=0;i<cnt;i++) ans+=a[i].first; return ans; } bool vis(int id) { for(int i=0;i<n;i++) if(a[i].second==id) return true; return false; } }; void solve() { int n; ll L=0,R=0; fin>>n; vector<int> p(n+1); vector<vector<pii>> g(n+1); for(int i=1;i<=n;i++) fin>>p[i]; for(int i=1,u,v,w;i<=n-1;i++) { fin>>u>>v>>w; g[u].push_back({v,w}); g[v].push_back({u,w}); R+=w; } vector<ll>dp[2]; ll mid; bool flag; function<void(int,int)> dfs=[&](int u,int f) { int son=0; Max<3>s; for(int i=0;i<g[u].size();i++) { int v=g[u][i].first; int w=g[u][i].second; if(v==f) continue; dfs(v,u); if(!flag)return; son++; s.insert({min(dp[0][v]+max(w-p[v],0),dp[1][v]+w),i}); } if(son==0) { dp[0][u]=0; return; } if(s.sum(2)<=mid) dp[0][u]=s.sum(1); dp[1][u]=inf; for(int i=0;i<g[u].size();i++) { int v=g[u][i].first; int w=g[u][i].second; if(v==f || !s.vis(i)) continue; Max<3> s1=s; s1.erase(i); s1.insert({min(dp[0][v]+max(w-p[v]-p[u],0),dp[1][v]+max(w-p[u],0)),i}); if(s1.sum(2)<=mid) dp[1][u]=min(dp[1][u],s1.sum(1)); } if(dp[0][u]==inf && dp[1][u]==inf) flag=false; }; while(L<R) { mid=(L+R)/2; flag=true; dp[0]=dp[1]=vector<ll>(n+1,inf); dfs(1,0); if(flag && (dp[0][1]<=mid || dp[1][1]<=mid)) R=mid; else L=mid+1; } cout<<L<<'\n'; } int main() { ios::sync_with_stdio(false); cin.tie(nullptr); int T; fin>>T; while(T--) solve(); return 0; }