1. 程式人生 > 實用技巧 >「CTSC2018」暴力寫掛

「CTSC2018」暴力寫掛

給定兩棵樹,求 \(\text{depth}(x)+\text{depth}(y)-(\text{depth}(\text{lca}(x,y))+\text{depth}'(\text{lca}'(x,y)))\) 的最大值,\(n\leqslant 366666\)

這裡有一個叫做邊分樹合併的東西,我們考慮邊分治過程中,把每次分治的邊看作點,分治過程中相鄰兩層的重邊連邊,當連通塊大小為 \(1\) 時,與點連邊,這樣會形成一棵二叉樹,其中葉子都是原樹中的點,非葉子是原樹中的邊。然後我們一開始把邊分樹拆成根節點到每個葉子節點的鏈,並按照某種順序合併,注意到邊分樹是二叉樹,所以我們可以使用和線段樹合併一樣的方法合併。並且對於每一條路徑 \((x,y)\)

的貢獻,它都可以在 \(x\)\(y\) 第一次被合併時在邊分樹中 \(x,y\) 的 lca 處統計到。

我們把答案改寫成 \(\dfrac{1}{2}(\text{dis}(x,y)+\text{depth}(x)+\text{depth}(y)-2\text{depth}'(\text{lca}'(x,y)))\) 的形式,我們考慮對第一棵樹進行邊分治,並建出每個點從根到這個點的邊分樹(\(n\) 條鏈),我們在邊分樹上每個非葉子結點維護以下資訊:分別來自左、右子樹所包含的原樹中的點中 \(\text{dis}(x)+\text{depth(x)}\) 的最大值 \(vl,vr\)

,其中 \(\text{dis}(x)\)\(x\) 到原樹中這個點對應的邊某一個端點的距離。

接著我們對第二棵樹進行 dfs,我們列舉 \(\text{lca}'(x,y)\),每遇到一條邊,就把兩個端點的邊分樹按照線段樹的合併方法合併起來,然後我們在合併過程中統計答案:假設這兩棵邊分樹同時包括某一個點,這個點在兩棵邊分樹的編號分別是 \(u,v\),那麼我們可以用 \(\max(vl_u+vr_v,vr_u+vl_v)-2\text{depth}'(\text{lca}'(x,y))\) 來更新全域性答案,並更新合併後的點的 \(vl,vr\),不難發現這樣統計的點對 \((x,y)\)

的 lca 一定都是我們當前列舉到的點。注意特判 \(x=y\) 的情況和邊分治前記得三度化。這樣做的複雜度是 \(O(n\log n)\) 的(邊分治和邊分樹合併都是 1 個 log 的)。

可能有些細節可以看程式碼(懶得開 long long 所以不做貓了):

#include<bits/stdc++.h>
#define int long long
using namespace std;
int const N=366671;
template<unsigned M>struct graph
{
    int target[2*M],pre[2*M],last[M],tot,w[2*M];
    void add(int x,int y,int z)
    {
        target[++tot]=y;
        pre[tot]=last[x];
        last[x]=tot;
        w[tot]=z;
    }
};
graph<N>g1,g3;graph<N<<1>g2;
int n,now,siz[N<<1],nows,ed,li,cnt,mx[N<<5|1][2],ch[N<<5|1][2],last[N],dis[N],
tmp,ans=-1e18,rt[N];
bool del[N<<2];
void dfs(int x,int fa)
{
    int las=x;
    for(int i=g1.last[x];i;i=g1.pre[i])
    {
        int tar=g1.target[i];
        if(tar==fa)continue;
        g2.add(las,++now,0),g2.add(now,las,0),g2.add(now,tar,g1.w[i]),g2.add(tar,now,g1.w[i]);
        dis[tar]=dis[x]+g1.w[i];las=now;dfs(tar,x);
    }
}
void dfs2(int x,int fa,int nowd,int op)
{
    if(!op)tmp++;
    if(x<=n)
    {
        cnt++;
        if(!last[x])rt[x]=last[x]=cnt,cnt++;
        ch[last[x]][op]=cnt;
        mx[last[x]][op]=dis[x]+nowd;
        last[x]=cnt;
    }
    for(int i=g2.last[x];i;i=g2.pre[i])
    {
        int tar=g2.target[i];
        if(tar==fa||del[i])continue;
        dfs2(tar,x,nowd+g2.w[i],op);
    }
}
void get(int x,int fa)
{
    siz[x]=1;
    for(int i=g2.last[x];i;i=g2.pre[i])
    {
        int tar=g2.target[i];
        if(tar==fa||del[i])continue;
        get(tar,x);siz[x]+=siz[tar];
        if(max(siz[tar],nows-siz[tar])<li)li=max(siz[tar],nows-siz[tar]),ed=(i+1)>>1;
    }
}
void solve(int x,int s)
{
    if(s==1)return;
    ed=li=1e9;nows=s;
    get(x,0);
    int r1=g2.target[(ed<<1)-1],r2=g2.target[ed<<1];
    del[(ed<<1)-1]=del[ed<<1]=1;tmp=0;
    dfs2(r1,r2,0,0),dfs2(r2,r1,g2.w[ed<<1],1);
    int tt=siz[x]-tmp;
    solve(r1,tmp);solve(r2,tt);
}
int merge(int x,int y,int t)
{
    if((!x)||(!y))return x+y;
    ans=max(ans,max(mx[x][0]+mx[y][1],mx[y][0]+mx[x][1])+2*t);
    mx[x][0]=max(mx[x][0],mx[y][0]),mx[x][1]=max(mx[x][1],mx[y][1]);
    ch[x][0]=merge(ch[x][0],ch[y][0],t),ch[x][1]=merge(ch[x][1],ch[y][1],t);
    return x;
}
void dfs3(int x,int fa,int nowd)
{
    ans=max(ans,2*(dis[x]-nowd));
    for(int i=g3.last[x];i;i=g3.pre[i])
    {
        int tar=g3.target[i];
        if(tar==fa)continue;
        dfs3(tar,x,nowd+g3.w[i]);
        rt[x]=merge(rt[x],rt[tar],-nowd);
    }
}
signed main()
{
    memset(mx,0xc0,sizeof(mx));
    int x,y,z;
    scanf("%lld",&n);now=n;
    for(int i=1;i<n;i++)scanf("%lld%lld%lld",&x,&y,&z),g1.add(x,y,z),g1.add(y,x,z);
    dfs(1,0);
    solve(1,now);
    for(int i=1;i<n;i++)scanf("%lld%lld%lld",&x,&y,&z),g3.add(x,y,z),g3.add(y,x,z);
    dfs3(1,0,0);
    printf("%lld",ans>>1);
    return 0;
}