1. 程式人生 > 實用技巧 >HDU6820 (2020杭電多校第5場1007)

HDU6820 (2020杭電多校第5場1007)

題意

有一個n個點構成的樹,每條邊有一個邊權d。求最多有一個點度數超過k的聯通子圖的邊權和最大值。

分析

  1. 首先k=0時答案為0

  2. dp[0][u]代表以u為根的子樹中所有點度數都小於等於k時的邊權和最大值,且u與它的父節點有連邊。dp[1][u]代表以u為根的子樹中存在一個點的度數大於k時的邊權和最大值,且u與它的父節點有連邊。

  3. \(dp[0][u]=max_{v_1,v_2,...,v_{k-1}}(\sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i])+d\),其中vu的兒子節點,du的父親節點到u的邊權值

  4. \(dp[1][u]=max(\sum_vdp[0][v],max_{v_1,v_2,...,v_{k-1}}(\sum_{i=v_1,v_2,...,v_{k-2}}dp[0][i]+dp[1][v_{k-1}]))+d\)

  5. \(ans=max_u(dp[1][u],max_{v_1,v_2,...,v_{k}}(\sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k]))\)

  6. 其中\(max_{v_1,v_2,...,v_{k}}(\sum_{i=v_1,v_2,...,v_{k-1}}dp[0][i]+dp[1][v_k])\)可以對dp[0][v]由大到小排序,然後對前k個計算\(\sum_{i=1}^kdp[1][i]+dp[1][v]-dp[0][v]\),對後cnt-k個計算\(\sum_{i=1}^kdp[1][i]+dp[1][k]-dp[0][v]\),複雜度為O(nlogn)

程式碼

#include<bits/stdc++.h>
using namespace std;
const int maxn=2e5+5;
typedef long long ll;
int n,k;

struct Node{int to,next;ll d;}edge[maxn*2];
int head[maxn],ecnt;
int cnt[maxn];
ll Ans,ans[2][maxn];
int son[maxn];
void init()
{
    memset(head,-1,sizeof(head[0])*(n+5));
    memset(cnt,0,sizeof(cnt[0])*(n+5));
    ecnt=0;
    Ans=0;
}
void addedge(int u,int v,ll d)
{
    edge[ecnt]={v,head[u],d};
    head[u]=ecnt++;
    edge[ecnt]={u,head[v],d};
    head[v]=ecnt++;
    cnt[u]++;cnt[v]++;
}

void dfs(int u,int fa,ll d)
{
    ans[1][u]=ans[0][u]=d;
    vector<ll>v0;
    for(int i=head[u];i!=-1;i=edge[i].next)
    {
        int v=edge[i].to;
        if(v==fa)continue;
        dfs(v,u,edge[i].d);

        ans[1][u]+=ans[0][v];
        v0.push_back(ans[0][v]);
    }

    sort(v0.begin(),v0.end(),greater<ll>());
    for(int i=0;i<min((ll)v0.size(),(ll)k-1);i++)
        ans[0][u]+=v0[i];

    for(int i=head[u],j=1;i!=-1;i=edge[i].next)
        if(edge[i].to!=fa)
            son[j]=edge[i].to,j++;
    sort(son+1,son+cnt[u]+1,[](int i,int j){
        return ans[0][i]>ans[0][j];
    });
    //return
    int nn=min(k-1,cnt[u]);
    ll ans1=d;
    if(k>=2)
    {
        for(int i=1;i<=nn;i++)
            ans1+=ans[0][son[i]];
        for(int i=1;i<=nn;i++)
            ans[1][u]=max(ans[1][u],ans1-ans[0][son[i]]+ans[1][son[i]]);
        ans1-=ans[0][son[nn]];
        for(int i=nn+1;i<=cnt[u];i++)
            ans[1][u]=max(ans[1][u],ans1+ans[1][son[i]]);
    }
    else
        ans[1][u]=max(ans[1][u],ans1);
    //dp
    nn=min(k,cnt[u]);
    ll ans2=0;
    for(int i=1;i<=nn;i++)
        ans2+=ans[0][son[i]];
    for(int i=1;i<=nn;i++)
        Ans=max(Ans,ans2-ans[0][son[i]]+ans[1][son[i]]);
    ans2-=ans[0][son[nn]];
    for(int i=nn+1;i<=cnt[u];i++)
        Ans=max(Ans,ans2+ans[1][son[i]]);
    
    Ans=max(Ans,ans[1][u]);
}

int main()
{
    int t;
    scanf("%d",&t);
    while(t--)
    {
        int u,v;ll d;
        scanf("%d%d",&n,&k);
        init();
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d%lld",&u,&v,&d);
            addedge(u,v,d);
        }
        if(k==0)
        {
            printf("0\n");
            continue;
        }
        for(int i=2;i<=n;i++)
            cnt[i]--;
        dfs(1,0,0);
        printf("%lld\n",Ans);
    }
    return 0;
}