LibreOJ #2478.「九省聯考 2018」林克卡特樹 樹形dp+帶權二分
阿新 • • 發佈:2019-02-15
題意
給出一棵n個節點的樹和k,邊有邊權,要求先從樹中選k條邊,然後把這k條邊刪掉,再加入k條邊權為0的邊,滿足操作完後的圖仍然是一棵樹。問新樹的帶權直徑最大是多少。
分析
不難發現我們要求的就是在樹中選出k+1條不相交的鏈使得其權值和最大。
當k比較小的時候,我們可以樹形dp,設表示以i為根的樹中選了j條鏈,節點i的度數為時的最大權值。
這樣做的複雜度是的。
當k變大之後,我們就可以用帶權二分來做。
具體來說就是二分一個權值,然後給每條路徑的權值加上,再去除k的限制後進行dp。若選出的鏈的數量不小於k+1,則把mid減小,否則把mid增大。
當某一刻選出的鏈數量恰好為k+1,則當前權值-mid*(k+1)就是答案了。
程式碼
#include<iostream>
#include<cstdio>
#include<cstdlib>
#include<cstring>
#include<algorithm>
typedef long long LL;
const int N=300005;
const LL inf=(LL)1e12;
int n,k,cnt,last[N];
LL mid;
struct data
{
LL x,y;
bool operator > (const data &d) const {return x>d.x||x==d.x&&y>d.y;}
bool operator < (const data &d) const {return x<d.x||x==d.x&&y<d.y;}
data operator + (const data &d) const {return (data){x+d.x,y+d.y};}
}f[N][3],tmp[3];
struct edge{int to,next,w;}e[N*2];
int read()
{
int x=0,f=1;char ch=getchar();
while (ch<'0'||ch>'9'){if(ch=='-')f=-1;ch=getchar();}
while (ch>='0'&&ch<='9'){x=x*10+ch-'0';ch=getchar();}
return x*f;
}
void addedge(int u,int v,int w)
{
e[++cnt].to=v;e[cnt].w=w;e[cnt].next=last[u];last[u]=cnt;
e[++cnt].to=u;e[cnt].w=w;e[cnt].next=last[v];last[v]=cnt;
}
void dp(int x,int fa)
{
f[x][0]=(data){0,0};f[x][1]=(data){-inf,0};f[x][2]=(data){mid,1};
for (int i=last[x];i;i=e[i].next)
{
int to=e[i].to;
if (to==fa) continue;
dp(to,x);
data u;tmp[0]=tmp[1]=tmp[2]=(data){-inf,0};
u=f[x][0]+f[to][0];tmp[0]=std::max(tmp[0],u);
u.x+=(LL)e[i].w+mid;u.y++;tmp[1]=std::max(tmp[1],u);
u=f[x][0]+f[to][1];tmp[0]=std::max(tmp[0],u);
u.x+=(LL)e[i].w;tmp[1]=std::max(tmp[1],u);
u=f[x][0]+f[to][2];tmp[0]=std::max(tmp[0],u);
u=f[x][1]+f[to][0];tmp[1]=std::max(tmp[1],u);
u.x+=(LL)e[i].w;tmp[2]=std::max(tmp[1],u);
u=f[x][1]+f[to][1];tmp[1]=std::max(tmp[1],u);
u.x+=e[i].w-mid;u.y--;tmp[2]=std::max(tmp[2],u);
u=f[x][1]+f[to][2];tmp[1]=std::max(tmp[1],u);
u=f[x][2]+f[to][0];tmp[2]=std::max(tmp[2],u);
u=f[x][2]+f[to][1];tmp[2]=std::max(tmp[2],u);
u=f[x][2]+f[to][2];tmp[2]=std::max(tmp[2],u);
f[x][0]=tmp[0];f[x][1]=tmp[1];f[x][2]=tmp[2];
}
}
int main()
{
n=read();k=read();
for (int i=1;i<n;i++)
{
int x=read(),y=read(),z=read();
addedge(x,y,z);
}
LL l=-inf,r=inf,ans;
while (l<=r)
{
mid=(l+r)/2;
dp(1,0);
data u=std::max(f[1][0],std::max(f[1][1],f[1][2]));
if (u.y>=k+1) ans=u.x,r=mid-1;
else l=mid+1;
}
printf("%lld",ans-(LL)(r+1)*(k+1));
return 0;
}