1. 程式人生 > >[學習筆記]虛樹

[學習筆記]虛樹

模板:(樹剖\(LCA\)+建虛樹)

#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

int main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}

具體建虛樹怎麼建可以看別人的部落格……我講的肯定沒有它們好

1、[SDOI2011]消耗戰

分析:人生第一道虛樹題。

難的就是建一棵虛樹,然後在虛樹上樹形 \(dp\)

首先,打出一個樹上字首最小值。因為無論怎樣,選一條最小的邊斷掉一定是最優的。建一棵虛樹,若遍歷到選定的點 \(x\),那麼 \(dp[x]=min(dis[x],\sum_{son\in x}val_{x->son})\),其中 \(val\) 為邊權。

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=250000+10;
const int inf=1e18;
int n,m,dp[maxn],dis[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next,val;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y,int w){
    e[++tot].to=y;
    e[tot].val=w;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dis[y]=min(dis[x],e[i].val);
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

void dfs(int x,int flag){
    dp[x]=dis[x];
    if(flag){
        for(int i=fir[x];i;i=nxt[i])
            dfs(to[i],flag);
        fir[x]=vis[x]=0;
        return ;
    }
    int val=0;
    for(int i=fir[x],y;i;i=nxt[i]){
        y=to[i];
        dfs(y,vis[y]);
        val+=dp[y];
    }
    if(!fir[x]||vis[x]) val=inf;
    dp[x]=min(dp[x],val);
    fir[x]=vis[x]=0;
}

signed main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read(),w=read();
        add(x,y,w);add(y,x,w);
    }
    dis[1]=inf;
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;
        sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top]);
                    if(lca!=sta[--Top]) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        dfs(1,0);
        printf("%lld\n",dp[1]);
    }
    return 0;
}

2、[HEOI2014]大工程

分析:這道題自己推的,很有成就感哈哈哈

方法與上題一樣,不過多一點細節

\(sub[x]\) 表示在虛樹上 \(x\) 的子樹內有多少個選定點

這些點對 \((x,y)\) 對答案的貢獻要分兩類討論:

1、\(x=lca(x,y)\) ,那麼直接在 \(vis[x]=1\) 的時候算掉

2、\(x,y\) 在兩棵不同的子樹內,那就一邊更新 \(sub[x]\) 一邊算

void dfs(int x){
    int now=0;
    for(int i=fir[x];i;i=nxt[i]){
        dfs(to[i]);
        now+=sub[x]*sub[to[i]];
        sub[x]+=sub[to[i]];
        sub[to[i]]=0;
    }
    ans-=2*now*dep[x];
    if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
}

找最小邊權就是記錄一下最小值和次小值,然後更新 \(ans\)

找最大邊權同個道理

int dfs_min(int x){
    int Min=inf,sec=inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=min(sec,dfs_min(to[i]));
        if(Min>sec) swap(Min,sec);
    }
    if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
    if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
    if(vis[x]) Min=dep[x];
    return Min;
}

int dfs_max(int x){
    int Max=-inf,sec=-inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=max(sec,dfs_max(to[i]));
        if(Max<sec) swap(Max,sec);
    }
    if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
    if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
    if(vis[x]&&Max==-inf) Max=dep[x];
    fir[x]=0;
    return Max;
}

那個前式鏈向星陣列 \(fir[x]\) 一定要在 \(dfsmax()\) 的時候清空!!!

\(Code\ Below:\)

#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1000000+10;
const int inf=1e18;
int n,m,dp[maxn],vis[maxn],h[maxn],sub[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim,ans;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

void dfs(int x){
    int now=0;
    for(int i=fir[x];i;i=nxt[i]){
        dfs(to[i]);
        now+=sub[x]*sub[to[i]];
        sub[x]+=sub[to[i]];
        sub[to[i]]=0;
    }
    ans-=2*now*dep[x];
    if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
    //printf("x=%lld,now=%lld,ans=%lld,sub[x]=%lld,dep[x]=%lld\n",x,now,ans,sub[x],dep[x]);
}

int dfs_min(int x){
    int Min=inf,sec=inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=min(sec,dfs_min(to[i]));
        if(Min>sec) swap(Min,sec);
    }
    if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
    if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
    if(vis[x]) Min=dep[x];
    return Min;
}

int dfs_max(int x){
    int Max=-inf,sec=-inf;
    for(int i=fir[x];i;i=nxt[i]){
        sec=max(sec,dfs_max(to[i]));
        if(Max<sec) swap(Max,sec);
    }
    if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
    if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
    if(vis[x]&&Max==-inf) Max=dep[x];
    fir[x]=0;
    return Max;
}

signed main()
{
    n=read();
    int x,y,w,k,lca;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        if(k==1){
            printf("0 0 0\n");
            continue;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        ans=0;
        for(int i=1;i<=k;i++) ans+=(k-1)*dep[h[i]];
        dfs(1);sub[1]=0;
        printf("%lld ",ans);
        ans=inf;dfs_min(1);
        printf("%lld ",ans);
        ans=-inf;dfs_max(1);
        printf("%lld\n",ans);
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}

3、CF613D Kingdom and its Cities

分析:在虛樹上樹形 \(dp\) 的時候分三種情況:

\(P.S:sum\) 表示有多少個兒子已經選了

1、\(vis[x]=1\),那麼不能選 \(x\) 來斷掉兒子的退路,那麼 \(dp[x]=\sum_{son\in x} dp[son]\)

2、\(vis[x]=0,sum>1\),那就直接選 \(x\)\(x\) 的子樹已經被 \(x\) 封死了

3、\(vis[x]=0,sum\leq 1\),那就傳到 \(x\) 的父親上,讓 \(x\) 的父親解決好了

\(Code\ Below:\)

#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;

struct node{
    int to,next;
}e[maxn<<1];

inline int read(){
    register int x=0,f=1;char ch=getchar();
    while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
    while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
    return (f==1)?x:-x;
}
inline void add(int x,int y){
    e[++tot].to=y;
    e[tot].next=head[x];
    head[x]=tot;
}
inline void addedge(int x,int y){
    to[++cnt]=y;
    nxt[cnt]=fir[x];
    fir[x]=cnt;
}

void dfs1(int x,int f){
    siz[x]=1;fa[x]=f;
    dep[x]=dep[f]+1;
    int maxson=-1;
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==f) continue;
        dfs1(y,x);
        siz[x]+=siz[y];
        if(maxson<siz[y]){
            maxson=siz[y];
            son[x]=y;
        }
    }
}

void dfs2(int x,int topf){
    id[x]=++tim;
    top[x]=topf;
    if(son[x]) dfs2(son[x],topf);
    for(int i=head[x],y;i;i=e[i].next){
        y=e[i].to;
        if(y==fa[x]||y==son[x]) continue;
        dfs2(y,y);
    }
}

int LCA(int x,int y){
    while(top[x]!=top[y]){
        if(dep[top[x]]<dep[top[y]]) swap(x,y);
        x=fa[top[x]];
    }
    if(dep[x]>dep[y]) swap(x,y);
    return x;
}

bool cmp(int a,int b){
    return id[a]<id[b];
}

int dfs(int x){
    int ans=0,sum=0;
    for(int i=fir[x];i;i=nxt[i])
        ans+=dfs(to[i]),sum+=dp[to[i]];
    if(vis[x]) dp[x]=1,ans+=sum;
    else if(sum>1) dp[x]=0,ans++;
    else dp[x]=sum;
    fir[x]=0;
    return ans;
}

int main()
{
    n=read();
    int x,y,w,k,lca,flag;
    for(int i=1;i<n;i++){
        x=read(),y=read();
        add(x,y);add(y,x);
    }
    dfs1(1,0);dfs2(1,1);
    m=read();
    for(int t=1;t<=m;t++){
        k=read();
        for(int i=1;i<=k;i++){
            h[i]=read();
            vis[h[i]]=1;
        }
        flag=0;
        for(int i=1;i<=k;i++)
            flag|=vis[fa[h[i]]];
        if(flag){
            printf("-1\n");
            for(int i=1;i<=k;i++) vis[h[i]]=0;
            continue;
        }
        sort(h+1,h+k+1,cmp);
        cnt=0;sta[Top=1]=1;
        for(int i=1;i<=k;i++){
            lca=LCA(sta[Top],h[i]);
            while(dep[lca]<dep[sta[Top]]){
                if(dep[sta[Top-1]]<=dep[lca]){
                    addedge(lca,sta[Top--]);
                    if(sta[Top]!=lca) sta[++Top]=lca;
                    break;
                }
                addedge(sta[Top-1],sta[Top]);
                Top--;
            }
            if(sta[Top]!=h[i]) sta[++Top]=h[i];
        }
        while(--Top) addedge(sta[Top],sta[Top+1]);
        printf("%d\n",dfs(1));
        for(int i=1;i<=k;i++) vis[h[i]]=0;
    }
    return 0;
}