2020 Multi-University Training Contest 2 In Search of Gold
2020 Multi-University Training Contest 2 In Search of Gold
題目大意:
給你一顆大小是n的樹,每一條邊都有兩個值一個a一個b,選擇k條a邊,n-1-k 條b邊,問這棵樹的直徑最短是多少?
題解:
很自然的定義 \(dp[i][j]\) 表示對於子樹 \(i\) ,有 \(j\) 條邊來自 a 的最遠距離的最小值。
那麼轉移方程就是
\(dp[u][j]=min(dp[u][j],max(dp[v][x]+a,dp[u][j-x-1]),max(dp[v][x]+b,dp[u][j-x]))\)
這個轉移方程表示的是,如果從子節點v轉移,那麼有兩種選擇,一種就是這個子節點要這個條a邊,一個是這個子節點不要這條a邊的轉移。
但是呢,這樣轉移會出現一個問題。
對於這棵樹,假設k=1,括號左邊表示a,右邊表示b,到3這個點有兩種可能選擇(如果選了一條a邊) (5,1) (4,4)
所有按照上面的轉移方程3這個點如果選一條a邊的結果是 (4,4) 。
這樣的話,那麼上面3到4這條邊是 (1000,1) 那麼直徑是不是 4+4=8,那麼沒有下面選擇 (5,1) 直徑是 5+1 更優,如果下面選擇 (5,1) ,那麼3到4這條邊是(1000,100) 那麼也會出現問題。
所以如果直接這樣轉移肯定會出問題的,那怎麼轉移是對的呢?
先思考一下為什麼這個會影響結果?其實就是因為子樹可能成為直徑,如果保證子樹直徑小於等於整棵樹的直徑,那麼是不是就沒什麼影響了,但是怎麼判斷子樹的直徑有沒有大於整棵樹的直徑呢?這個可以二分求解,所以二分一下這個直徑長度,如果這個長度子節點合併的時候大於這個直徑長度,那麼就不從這裡轉移即可,不從這裡轉移表示不要這個狀態。
最後怎麼判斷這個check是否為真?這個很好判斷,因為 \(dp[i][j]\) 的轉移必須要求以 \(i\) 為根節點的樹的直徑都小於等於這個mid。
#include <bits/stdc++.h> #define inf 0x3f3f3f3f #define inf64 0x3f3f3f3f3f3f3f3f using namespace std; typedef long long ll; const int maxn = 2e4+10; int head[maxn],nxt[maxn<<1],to[maxn<<1],cnt,a[maxn<<1],b[maxn<<1]; void add(int u,int v,int x,int y){ ++cnt,to[cnt]=v,nxt[cnt]=head[u],a[cnt]=x,b[cnt]=y,head[u]=cnt; ++cnt,to[cnt]=u,nxt[cnt]=head[v],a[cnt]=x,b[cnt]=y,head[v]=cnt; } ll dp[maxn][22]; //dp[i][j] 表示已i為根節點的子樹,選擇了j條a邊,最遠距離最小。 int n,k; int siz[maxn]; ll tmp[22]; //tmp[i] 表示選了i條a邊的最遠距離 void dfs(int u,int pre,ll x){ dp[u][0]=siz[u]=0; for(int i=head[u];i;i=nxt[i]){ int v = to[i]; if(v == pre) continue; dfs(v,u,x); int num = min(k,siz[u]+siz[v]+1); for(int j=0;j<=num;j++) tmp[j]=x+1; for(int j=0;j<=siz[u];j++){ for(int h=0;h<=siz[v]&&h+j<=k;h++){ if(dp[u][j]+dp[v][h]+a[i]<=x){ tmp[j+h+1]=min(tmp[j+h+1],max(dp[u][j],dp[v][h]+a[i])); } if(dp[u][j]+dp[v][h]+b[i]<=x){ tmp[j+h]=min(tmp[j+h],max(dp[u][j],dp[v][h]+b[i])); } } } siz[u]=num; for(int j=0;j<=siz[u];j++) dp[u][j]=tmp[j]; } } bool check(ll x){ dfs(1,0,x); if(dp[1][k]<=x) return true; return false; } void init(int n){ cnt=0; for(int i=0;i<=n;i++) head[i]=0; } int main(){ int t; scanf("%d",&t); while(t--){ scanf("%d%d",&n,&k); init(n); for(int i=1;i<n;i++){ int u,v,x,y; scanf("%d%d%d%d",&u,&v,&x,&y); add(u,v,x,y); } ll l=1,r=inf64,ans=0; while(l<=r){ ll mid=(l+r)>>1ll; if(check(mid)) ans=mid,r=mid-1; else l=mid+1; } printf("%lld\n",ans); } }