「CTSC2018」暴力寫掛
給定兩棵樹,求 \(\text{depth}(x)+\text{depth}(y)-(\text{depth}(\text{lca}(x,y))+\text{depth}'(\text{lca}'(x,y)))\) 的最大值,\(n\leqslant 366666\)。
這裡有一個叫做邊分樹合併的東西,我們考慮邊分治過程中,把每次分治的邊看作點,分治過程中相鄰兩層的重邊連邊,當連通塊大小為 \(1\) 時,與點連邊,這樣會形成一棵二叉樹,其中葉子都是原樹中的點,非葉子是原樹中的邊。然後我們一開始把邊分樹拆成根節點到每個葉子節點的鏈,並按照某種順序合併,注意到邊分樹是二叉樹,所以我們可以使用和線段樹合併一樣的方法合併。並且對於每一條路徑 \((x,y)\)
我們把答案改寫成 \(\dfrac{1}{2}(\text{dis}(x,y)+\text{depth}(x)+\text{depth}(y)-2\text{depth}'(\text{lca}'(x,y)))\) 的形式,我們考慮對第一棵樹進行邊分治,並建出每個點從根到這個點的邊分樹(\(n\) 條鏈),我們在邊分樹上每個非葉子結點維護以下資訊:分別來自左、右子樹所包含的原樹中的點中 \(\text{dis}(x)+\text{depth(x)}\) 的最大值 \(vl,vr\)
接著我們對第二棵樹進行 dfs,我們列舉 \(\text{lca}'(x,y)\),每遇到一條邊,就把兩個端點的邊分樹按照線段樹的合併方法合併起來,然後我們在合併過程中統計答案:假設這兩棵邊分樹同時包括某一個點,這個點在兩棵邊分樹的編號分別是 \(u,v\),那麼我們可以用 \(\max(vl_u+vr_v,vr_u+vl_v)-2\text{depth}'(\text{lca}'(x,y))\) 來更新全域性答案,並更新合併後的點的 \(vl,vr\),不難發現這樣統計的點對 \((x,y)\)
可能有些細節可以看程式碼(懶得開 long long 所以不做貓了):
#include<bits/stdc++.h>
#define int long long
using namespace std;
int const N=366671;
template<unsigned M>struct graph
{
int target[2*M],pre[2*M],last[M],tot,w[2*M];
void add(int x,int y,int z)
{
target[++tot]=y;
pre[tot]=last[x];
last[x]=tot;
w[tot]=z;
}
};
graph<N>g1,g3;graph<N<<1>g2;
int n,now,siz[N<<1],nows,ed,li,cnt,mx[N<<5|1][2],ch[N<<5|1][2],last[N],dis[N],
tmp,ans=-1e18,rt[N];
bool del[N<<2];
void dfs(int x,int fa)
{
int las=x;
for(int i=g1.last[x];i;i=g1.pre[i])
{
int tar=g1.target[i];
if(tar==fa)continue;
g2.add(las,++now,0),g2.add(now,las,0),g2.add(now,tar,g1.w[i]),g2.add(tar,now,g1.w[i]);
dis[tar]=dis[x]+g1.w[i];las=now;dfs(tar,x);
}
}
void dfs2(int x,int fa,int nowd,int op)
{
if(!op)tmp++;
if(x<=n)
{
cnt++;
if(!last[x])rt[x]=last[x]=cnt,cnt++;
ch[last[x]][op]=cnt;
mx[last[x]][op]=dis[x]+nowd;
last[x]=cnt;
}
for(int i=g2.last[x];i;i=g2.pre[i])
{
int tar=g2.target[i];
if(tar==fa||del[i])continue;
dfs2(tar,x,nowd+g2.w[i],op);
}
}
void get(int x,int fa)
{
siz[x]=1;
for(int i=g2.last[x];i;i=g2.pre[i])
{
int tar=g2.target[i];
if(tar==fa||del[i])continue;
get(tar,x);siz[x]+=siz[tar];
if(max(siz[tar],nows-siz[tar])<li)li=max(siz[tar],nows-siz[tar]),ed=(i+1)>>1;
}
}
void solve(int x,int s)
{
if(s==1)return;
ed=li=1e9;nows=s;
get(x,0);
int r1=g2.target[(ed<<1)-1],r2=g2.target[ed<<1];
del[(ed<<1)-1]=del[ed<<1]=1;tmp=0;
dfs2(r1,r2,0,0),dfs2(r2,r1,g2.w[ed<<1],1);
int tt=siz[x]-tmp;
solve(r1,tmp);solve(r2,tt);
}
int merge(int x,int y,int t)
{
if((!x)||(!y))return x+y;
ans=max(ans,max(mx[x][0]+mx[y][1],mx[y][0]+mx[x][1])+2*t);
mx[x][0]=max(mx[x][0],mx[y][0]),mx[x][1]=max(mx[x][1],mx[y][1]);
ch[x][0]=merge(ch[x][0],ch[y][0],t),ch[x][1]=merge(ch[x][1],ch[y][1],t);
return x;
}
void dfs3(int x,int fa,int nowd)
{
ans=max(ans,2*(dis[x]-nowd));
for(int i=g3.last[x];i;i=g3.pre[i])
{
int tar=g3.target[i];
if(tar==fa)continue;
dfs3(tar,x,nowd+g3.w[i]);
rt[x]=merge(rt[x],rt[tar],-nowd);
}
}
signed main()
{
memset(mx,0xc0,sizeof(mx));
int x,y,z;
scanf("%lld",&n);now=n;
for(int i=1;i<n;i++)scanf("%lld%lld%lld",&x,&y,&z),g1.add(x,y,z),g1.add(y,x,z);
dfs(1,0);
solve(1,now);
for(int i=1;i<n;i++)scanf("%lld%lld%lld",&x,&y,&z),g3.add(x,y,z),g3.add(y,x,z);
dfs3(1,0,0);
printf("%lld",ans>>1);
return 0;
}