1. 程式人生 > 實用技巧 >2020 Multi-University Training Contest 2 In Search of Gold

2020 Multi-University Training Contest 2 In Search of Gold

2020 Multi-University Training Contest 2 In Search of Gold

題目大意:

給你一顆大小是n的樹,每一條邊都有兩個值一個a一個b,選擇k條a邊,n-1-k 條b邊,問這棵樹的直徑最短是多少?

題解:

很自然的定義 \(dp[i][j]\) 表示對於子樹 \(i\) ,有 \(j\) 條邊來自 a 的最遠距離的最小值。

那麼轉移方程就是

\(dp[u][j]=min(dp[u][j],max(dp[v][x]+a,dp[u][j-x-1]),max(dp[v][x]+b,dp[u][j-x]))\)

這個轉移方程表示的是,如果從子節點v轉移,那麼有兩種選擇,一種就是這個子節點要這個條a邊,一個是這個子節點不要這條a邊的轉移。

但是呢,這樣轉移會出現一個問題。

對於這棵樹,假設k=1,括號左邊表示a,右邊表示b,到3這個點有兩種可能選擇(如果選了一條a邊) (5,1) (4,4)

所有按照上面的轉移方程3這個點如果選一條a邊的結果是 (4,4) 。

這樣的話,那麼上面3到4這條邊是 (1000,1) 那麼直徑是不是 4+4=8,那麼沒有下面選擇 (5,1) 直徑是 5+1 更優,如果下面選擇 (5,1) ,那麼3到4這條邊是(1000,100) 那麼也會出現問題。

所以如果直接這樣轉移肯定會出問題的,那怎麼轉移是對的呢?

先思考一下為什麼這個會影響結果?其實就是因為子樹可能成為直徑,如果保證子樹直徑小於等於整棵樹的直徑,那麼是不是就沒什麼影響了,但是怎麼判斷子樹的直徑有沒有大於整棵樹的直徑呢?這個可以二分求解,所以二分一下這個直徑長度,如果這個長度子節點合併的時候大於這個直徑長度,那麼就不從這裡轉移即可,不從這裡轉移表示不要這個狀態。

最後怎麼判斷這個check是否為真?這個很好判斷,因為 \(dp[i][j]\) 的轉移必須要求以 \(i\) 為根節點的樹的直徑都小於等於這個mid。

#include <bits/stdc++.h>
#define inf 0x3f3f3f3f
#define inf64 0x3f3f3f3f3f3f3f3f
using namespace std;
typedef long long ll;
const int maxn = 2e4+10;
int head[maxn],nxt[maxn<<1],to[maxn<<1],cnt,a[maxn<<1],b[maxn<<1];
void add(int u,int v,int x,int y){
    ++cnt,to[cnt]=v,nxt[cnt]=head[u],a[cnt]=x,b[cnt]=y,head[u]=cnt;
    ++cnt,to[cnt]=u,nxt[cnt]=head[v],a[cnt]=x,b[cnt]=y,head[v]=cnt;
}
ll dp[maxn][22];
//dp[i][j] 表示已i為根節點的子樹,選擇了j條a邊,最遠距離最小。
int n,k;
int siz[maxn];
ll tmp[22];
//tmp[i] 表示選了i條a邊的最遠距離
void dfs(int u,int pre,ll x){
    dp[u][0]=siz[u]=0;
    for(int i=head[u];i;i=nxt[i]){
        int v = to[i];
        if(v == pre) continue;
        dfs(v,u,x);
        int num = min(k,siz[u]+siz[v]+1);
        for(int j=0;j<=num;j++) tmp[j]=x+1;
        for(int j=0;j<=siz[u];j++){
            for(int h=0;h<=siz[v]&&h+j<=k;h++){
                if(dp[u][j]+dp[v][h]+a[i]<=x){
                    tmp[j+h+1]=min(tmp[j+h+1],max(dp[u][j],dp[v][h]+a[i]));
                }
                if(dp[u][j]+dp[v][h]+b[i]<=x){
                    tmp[j+h]=min(tmp[j+h],max(dp[u][j],dp[v][h]+b[i]));
                }
            }
        }
        siz[u]=num;
        for(int j=0;j<=siz[u];j++) dp[u][j]=tmp[j];
    }
}

bool check(ll x){
    dfs(1,0,x);
    if(dp[1][k]<=x) return true;
    return false;
}
void init(int n){
    cnt=0;
    for(int i=0;i<=n;i++) head[i]=0;
}
int main(){
    int t;
    scanf("%d",&t);
    while(t--){
        scanf("%d%d",&n,&k);
        init(n);
        for(int i=1;i<n;i++){
            int u,v,x,y;
            scanf("%d%d%d%d",&u,&v,&x,&y);
            add(u,v,x,y);
        }
        ll l=1,r=inf64,ans=0;
        while(l<=r){
            ll mid=(l+r)>>1ll;
            if(check(mid)) ans=mid,r=mid-1;
            else l=mid+1;
        }
        printf("%lld\n",ans);
    }
}