(洛谷)P4657 chase
阿新 • • 發佈:2021-07-21
基本思路 :
-
首先令a陣列表示該點的權值,c陣列表示該點所連線的所有點的權值和。
-
如果我們知道了該點的前驅,那該點的權值就為 $ c[now]-a[pre] $ 。
遞進 :
問題在於這是一棵無根樹,那我們可以任意定義一點為根,
然後設兩個陣列 $ s_{i,j} , d_{i,j} $ 分別表示從子樹 i 中某個點走到 i 和從 i 走到子樹 i 中某個點,放置了 j 次的最大收益。那問題就變成求樹的直徑了。
注意 :
因為我們需要知道一個點的前驅,所以 s陣列 是記錄該子樹的根節點的, d陣列 是不記錄該子樹的根節點的。
另外,起始點一定會放置磁鐵,因為如果逃亡者降落一個點後再跑到其他點放置磁鐵,那顯然不如直接降落在該點放置磁鐵更優。
最後一點,如果是起始點,因為其沒有前驅,所以該點的貢獻為 $ c[now] $ 。
樹的直徑 :
因為樹的路徑不能重複,所以我們可以列舉完一個 $ to $ 節點後先與 $ now $ 的 $ d $ 陣列和 $ s $ 陣列更新一邊答案,再將 $ to $ 節點的答案合併到 $ now $ 的答案中。
時間複雜度 :$ O_{(nv)} $
- code
#include <bits/stdc++.h> #define re register #define int long long #define db double #define pir make_pair using namespace std; const int maxn=100010; inline int read() { int s=0,w=1; char ch=getchar(); while(ch<'0'||ch>'9') { if(ch=='-') w=-1;ch=getchar(); } while(ch>='0'&&ch<='9') { s=s*10+ch-'0'; ch=getchar(); } return s*w; } int cnt,head[maxn],s[maxn][110],d[maxn][110],a[maxn],c[maxn]; struct EDGE { int nxt,var; } edge[maxn<<1]; inline void add(int a,int b) { edge[++cnt]=(EDGE){head[a],b};head[a]=cnt; } int n,v,ans; int sl[110],dl[110]; inline void dfs(int now,int pre) { s[now][1]=c[now]; for(re int i=head[now],to;i;i=edge[i].nxt) { if((to=edge[i].var)==pre) continue; dfs(to,now); int maxs=0,maxd=0; for(re int i=v;i>=2;i--) { sl[i]=max(s[to][i],c[now]-a[to]+s[to][i-1]); dl[i]=max(d[to][i],c[to]-a[now]+d[to][i-1]); maxs=max(maxs,s[now][v-i]); maxd=max(maxd,d[now][v-i]); ans=max(ans,sl[i]+maxd); ans=max(ans,dl[i]+maxs); } maxs=max(maxs,s[now][v-1]); maxd=max(maxd,d[now][v-1]); sl[1]=max(s[to][1],c[now]); dl[1]=max(d[to][1],c[to]-a[now]); ans=max(ans,sl[1]+maxd); ans=max(ans,dl[1]+maxs); s[now][1]=max(s[now][1],sl[1]); d[now][1]=max(d[now][1],dl[1]); maxs=max(maxs,s[now][v]); maxd=max(maxd,d[now][v]); ans=max(ans,sl[0]+maxd); ans=max(ans,dl[0]+maxs); for(re int i=1;i<=v;i++) { s[now][i]=max(s[now][i],sl[i]); d[now][i]=max(d[now][i],dl[i]); maxs=max(maxs,s[now][i]); maxd=max(maxd,d[now][i]); } ans=max(ans,max(maxs,maxd)); } } signed main(void) { // freopen("chase9.in","r",stdin); n=read(),v=read(); for(re int i=1;i<=n;i++) a[i]=read(); for(re int i=1,u,e;i<n;i++) { u=read(),e=read(); add(u,e); add(e,u); c[u]+=a[e]; c[e]+=a[u]; } if(!v) { printf("0"); return 0; } dfs(1,0); printf("%lld",ans); }