1. 程式人生 > >bzoj 4987: Tree 樹形

bzoj 4987: Tree 樹形

Description
從前有棵樹。
找出kk個點A1A2AkA_1,A_2,…,A_k
使得dis(Ai,Ai+1),(1<=i<=k1)∑dis(A_i,Ai+1),(1<=i<=k-1)最小。

Input
第一行兩個正整數n,kn,k,表示數的頂點數和需要選出的點個數。
接下來 n1n-1 行每行 33 個非負整數x,y,zx,y,z,表示從存在一條從 xxyy 權值為 zz 的邊。
1<

=k<=n1<=k<=n
1<x,y<=n1<x,y<=n
1<=z<=1051<=z<=10^5
n<=3000n <= 3000
Output

一行一個整數,表示最小的距離和。
Sample Input
10 7
1 2 35129
2 3 42976
3 4 24497
2 5 83165
1 6 4748
5 7 38311
4 8 70052
3 9 3561
8 10 80238

Sample Output
184524

分析:
一個很顯然的結論,所選的點必定是一個連通塊。
因為從 xx 走到 yy ,沿途的所有點我們選了並且加在他們中間,答案不會改變。
怎樣統計連通塊的答案呢?
這個其實就是每條邊走兩次減去一條最長鏈。
假如我們把最長鏈上的點當做關鍵點。我們設f[i][j][0/1/2]f[i][j][0/1/2]表示在ii的子樹中,選了jj個點,選了0/1/20/1/2個關鍵點的答案。
考慮一棵子樹的答案與前面子樹如何合併。顯然,只有當這棵子樹內只有一個關鍵點時,最長鏈必定經過當前點連向該子樹的邊,此時這條邊只要算一倍答案,否則要算兩倍的答案。
聽說size大小合併是O

(n2)O(n^2)的,一開始以為是O(n3)O(n^3)一直以為不會做。

程式碼:

/**************************************************************
    Problem: 4987
    User: ypxrain
    Language: C++
    Result: Accepted
    Time:2928 ms
    Memory:107512 kb
****************************************************************/
 
#include <iostream>
#include <cstdio>
#include <cmath>
 
const int maxn=3007;
const int inf=0x3f3f3f3f;
 
using namespace std;
 
int n,m,x,y,w,cnt,ans;
int ls[maxn],f[maxn][maxn][3],tmp[maxn][3],size[maxn];
 
struct edge{
    int y,w,next;
}g[maxn*2];
 
void add(int x,int y,int w)
{
    g[++cnt]=(edge){y,w,ls[x]};
    ls[x]=cnt;
}
 
void dfs(int x,int fa)
{
    size[x]=1;
    for (int i=0;i<=n;i++)
    {
        for (int j=0;j<=2;j++) f[x][i][j]=inf;
    }
    f[x][1][0]=f[x][1][1]=f[x][1][2]=0;
    for (int i=ls[x];i>0;i=g[i].next)
    {
        int y=g[i].y;
        if (y==fa) continue;
        dfs(y,x);
        for (int j=0;j<=size[x]+size[y];j++)
        {
            for (int k=0;k<=2;k++) tmp[j][k]=inf;
        }
        for (int j=0;j<=size[x];j++)
        {
            for (int k=0;k<=size[y];k++)
            {
                for (int a=0;a<=2;a++)
                {
                    for (int b=0;a+b<=2;b++)
                    {
                        int len;
                        if (b==1) len=g[i].w;
                             else len=g[i].w*2;
                        tmp[j+k][a+b]=min(tmp[j+k][a+b],f[x][j][a]+f[y][k][b]+len);
                    }
                }
            }
        }
        size[x]+=size[y];
        for (int j=0;j<=size[x];j++)
        {
            for (int k=0;k<=2;k++) f[x][j][k]=min(f[x][j][k],tmp[j][k]);
        }
    }
    ans=min(ans,f[x][m][2]);
}
 
int main()
{
    scanf("%d%d",&n,&m);
    for (int i=1;i<n;i++)
    {
        scanf("%d%d%d",&x,&y,&w);
        add(x,y,w);
        add(y,x,w);
    }
    ans=inf;
    dfs(1,0);
    printf("%d",ans);
}