[學習筆記]虛樹
模板:(樹剖\(LCA\)+建虛樹)
#include <bits/stdc++.h> using namespace std; const int maxn=100000+10; int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top; int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt; int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim; struct node{ int to,next; }e[maxn<<1]; inline int read(){ register int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return (f==1)?x:-x; } inline void add(int x,int y){ e[++tot].to=y; e[tot].next=head[x]; head[x]=tot; } inline void addedge(int x,int y){ to[++cnt]=y; nxt[cnt]=fir[x]; fir[x]=cnt; } void dfs1(int x,int f){ siz[x]=1;fa[x]=f; dep[x]=dep[f]+1; int maxson=-1; for(int i=head[x],y;i;i=e[i].next){ y=e[i].to; if(y==f) continue; dfs1(y,x); siz[x]+=siz[y]; if(maxson<siz[y]){ maxson=siz[y]; son[x]=y; } } } void dfs2(int x,int topf){ id[x]=++tim; top[x]=topf; if(son[x]) dfs2(son[x],topf); for(int i=head[x],y;i;i=e[i].next){ y=e[i].to; if(y==fa[x]||y==son[x]) continue; dfs2(y,y); } } int LCA(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); return x; } bool cmp(int a,int b){ return id[a]<id[b]; } int main() { n=read(); int x,y,w,k,lca; for(int i=1;i<n;i++){ x=read(),y=read(); add(x,y);add(y,x); } dfs1(1,0);dfs2(1,1); m=read(); for(int t=1;t<=m;t++){ k=read(); for(int i=1;i<=k;i++){ h[i]=read(); vis[h[i]]=1; } sort(h+1,h+k+1,cmp); cnt=0;sta[Top=1]=1; for(int i=1;i<=k;i++){ lca=LCA(sta[Top],h[i]); while(dep[lca]<dep[sta[Top]]){ if(dep[sta[Top-1]]<=dep[lca]){ addedge(lca,sta[Top--]); if(sta[Top]!=lca) sta[++Top]=lca; break; } addedge(sta[Top-1],sta[Top]); Top--; } if(sta[Top]!=h[i]) sta[++Top]=h[i]; } while(--Top) addedge(sta[Top],sta[Top+1]); for(int i=1;i<=k;i++) vis[h[i]]=0; } return 0; }
具體建虛樹怎麼建可以看別人的部落格……我講的肯定沒有它們好
1、[SDOI2011]消耗戰
分析:人生第一道虛樹題。
難的就是建一棵虛樹,然後在虛樹上樹形 \(dp\)
首先,打出一個樹上字首最小值。因為無論怎樣,選一條最小的邊斷掉一定是最優的。建一棵虛樹,若遍歷到選定的點 \(x\),那麼 \(dp[x]=min(dis[x],\sum_{son\in x}val_{x->son})\),其中 \(val\) 為邊權。
\(Code\ Below:\)
#include <bits/stdc++.h> #define int long long using namespace std; const int maxn=250000+10; const int inf=1e18; int n,m,dp[maxn],dis[maxn],vis[maxn],h[maxn],sta[maxn],Top; int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt; int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim; struct node{ int to,next,val; }e[maxn<<1]; inline int read(){ register int x=0,f=1;char ch=getchar(); while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();} while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();} return (f==1)?x:-x; } inline void add(int x,int y,int w){ e[++tot].to=y; e[tot].val=w; e[tot].next=head[x]; head[x]=tot; } inline void addedge(int x,int y){ to[++cnt]=y; nxt[cnt]=fir[x]; fir[x]=cnt; } void dfs1(int x,int f){ siz[x]=1;fa[x]=f; dep[x]=dep[f]+1; int maxson=-1; for(int i=head[x],y;i;i=e[i].next){ y=e[i].to; if(y==f) continue; dis[y]=min(dis[x],e[i].val); dfs1(y,x); siz[x]+=siz[y]; if(maxson<siz[y]){ maxson=siz[y]; son[x]=y; } } } void dfs2(int x,int topf){ id[x]=++tim; top[x]=topf; if(son[x]) dfs2(son[x],topf); for(int i=head[x],y;i;i=e[i].next){ y=e[i].to; if(y==fa[x]||y==son[x]) continue; dfs2(y,y); } } int LCA(int x,int y){ while(top[x]!=top[y]){ if(dep[top[x]]<dep[top[y]]) swap(x,y); x=fa[top[x]]; } if(dep[x]>dep[y]) swap(x,y); return x; } bool cmp(int a,int b){ return id[a]<id[b]; } void dfs(int x,int flag){ dp[x]=dis[x]; if(flag){ for(int i=fir[x];i;i=nxt[i]) dfs(to[i],flag); fir[x]=vis[x]=0; return ; } int val=0; for(int i=fir[x],y;i;i=nxt[i]){ y=to[i]; dfs(y,vis[y]); val+=dp[y]; } if(!fir[x]||vis[x]) val=inf; dp[x]=min(dp[x],val); fir[x]=vis[x]=0; } signed main() { n=read(); int x,y,w,k,lca; for(int i=1;i<n;i++){ x=read(),y=read(),w=read(); add(x,y,w);add(y,x,w); } dis[1]=inf; dfs1(1,0);dfs2(1,1); m=read(); for(int t=1;t<=m;t++){ k=read(); for(int i=1;i<=k;i++){ h[i]=read(); vis[h[i]]=1; } sort(h+1,h+k+1,cmp); cnt=0; sta[Top=1]=1; for(int i=1;i<=k;i++){ lca=LCA(sta[Top],h[i]); while(dep[lca]<dep[sta[Top]]){ if(dep[sta[Top-1]]<=dep[lca]){ addedge(lca,sta[Top]); if(lca!=sta[--Top]) sta[++Top]=lca; break; } addedge(sta[Top-1],sta[Top]); Top--; } if(sta[Top]!=h[i]) sta[++Top]=h[i]; } while(--Top) addedge(sta[Top],sta[Top+1]); dfs(1,0); printf("%lld\n",dp[1]); } return 0; }
2、[HEOI2014]大工程
分析:這道題自己推的,很有成就感哈哈哈
方法與上題一樣,不過多一點細節
\(sub[x]\) 表示在虛樹上 \(x\) 的子樹內有多少個選定點
這些點對 \((x,y)\) 對答案的貢獻要分兩類討論:
1、\(x=lca(x,y)\) ,那麼直接在 \(vis[x]=1\) 的時候算掉
2、\(x,y\) 在兩棵不同的子樹內,那就一邊更新 \(sub[x]\) 一邊算
void dfs(int x){ int now=0; for(int i=fir[x];i;i=nxt[i]){ dfs(to[i]); now+=sub[x]*sub[to[i]]; sub[x]+=sub[to[i]]; sub[to[i]]=0; } ans-=2*now*dep[x]; if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++; }
找最小邊權就是記錄一下最小值和次小值,然後更新 \(ans\)
找最大邊權同個道理
int dfs_min(int x){
int Min=inf,sec=inf;
for(int i=fir[x];i;i=nxt[i]){
sec=min(sec,dfs_min(to[i]));
if(Min>sec) swap(Min,sec);
}
if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
if(vis[x]) Min=dep[x];
return Min;
}
int dfs_max(int x){
int Max=-inf,sec=-inf;
for(int i=fir[x];i;i=nxt[i]){
sec=max(sec,dfs_max(to[i]));
if(Max<sec) swap(Max,sec);
}
if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
if(vis[x]&&Max==-inf) Max=dep[x];
fir[x]=0;
return Max;
}
那個前式鏈向星陣列 \(fir[x]\) 一定要在 \(dfsmax()\) 的時候清空!!!
\(Code\ Below:\)
#include <bits/stdc++.h>
#define int long long
using namespace std;
const int maxn=1000000+10;
const int inf=1e18;
int n,m,dp[maxn],vis[maxn],h[maxn],sub[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim,ans;
struct node{
int to,next;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
void dfs(int x){
int now=0;
for(int i=fir[x];i;i=nxt[i]){
dfs(to[i]);
now+=sub[x]*sub[to[i]];
sub[x]+=sub[to[i]];
sub[to[i]]=0;
}
ans-=2*now*dep[x];
if(vis[x]) ans-=2*sub[x]*dep[x],sub[x]++;
//printf("x=%lld,now=%lld,ans=%lld,sub[x]=%lld,dep[x]=%lld\n",x,now,ans,sub[x],dep[x]);
}
int dfs_min(int x){
int Min=inf,sec=inf;
for(int i=fir[x];i;i=nxt[i]){
sec=min(sec,dfs_min(to[i]));
if(Min>sec) swap(Min,sec);
}
if(vis[x]&&Min!=inf) ans=min(ans,Min-dep[x]);
if(Min!=inf&&sec!=inf) ans=min(ans,Min+sec-2*dep[x]);
if(vis[x]) Min=dep[x];
return Min;
}
int dfs_max(int x){
int Max=-inf,sec=-inf;
for(int i=fir[x];i;i=nxt[i]){
sec=max(sec,dfs_max(to[i]));
if(Max<sec) swap(Max,sec);
}
if(vis[x]&&Max!=-inf) ans=max(ans,Max-dep[x]);
if(Max!=-inf&&sec!=-inf) ans=max(ans,Max+sec-2*dep[x]);
if(vis[x]&&Max==-inf) Max=dep[x];
fir[x]=0;
return Max;
}
signed main()
{
n=read();
int x,y,w,k,lca;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
if(k==1){
printf("0 0 0\n");
continue;
}
sort(h+1,h+k+1,cmp);
cnt=0;sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top--]);
if(sta[Top]!=lca) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
ans=0;
for(int i=1;i<=k;i++) ans+=(k-1)*dep[h[i]];
dfs(1);sub[1]=0;
printf("%lld ",ans);
ans=inf;dfs_min(1);
printf("%lld ",ans);
ans=-inf;dfs_max(1);
printf("%lld\n",ans);
for(int i=1;i<=k;i++) vis[h[i]]=0;
}
return 0;
}
3、CF613D Kingdom and its Cities
分析:在虛樹上樹形 \(dp\) 的時候分三種情況:
\(P.S:sum\) 表示有多少個兒子已經選了
1、\(vis[x]=1\),那麼不能選 \(x\) 來斷掉兒子的退路,那麼 \(dp[x]=\sum_{son\in x} dp[son]\)
2、\(vis[x]=0,sum>1\),那就直接選 \(x\),\(x\) 的子樹已經被 \(x\) 封死了
3、\(vis[x]=0,sum\leq 1\),那就傳到 \(x\) 的父親上,讓 \(x\) 的父親解決好了
\(Code\ Below:\)
#include <bits/stdc++.h>
using namespace std;
const int maxn=100000+10;
int n,m,dp[maxn],vis[maxn],h[maxn],sta[maxn],Top;
int fir[maxn],head[maxn],to[maxn<<1],nxt[maxn<<1],tot,cnt;
int top[maxn],dep[maxn],id[maxn],siz[maxn],son[maxn],fa[maxn],tim;
struct node{
int to,next;
}e[maxn<<1];
inline int read(){
register int x=0,f=1;char ch=getchar();
while(!isdigit(ch)){if(ch=='-')f=-1;ch=getchar();}
while(isdigit(ch)){x=(x<<3)+(x<<1)+ch-'0';ch=getchar();}
return (f==1)?x:-x;
}
inline void add(int x,int y){
e[++tot].to=y;
e[tot].next=head[x];
head[x]=tot;
}
inline void addedge(int x,int y){
to[++cnt]=y;
nxt[cnt]=fir[x];
fir[x]=cnt;
}
void dfs1(int x,int f){
siz[x]=1;fa[x]=f;
dep[x]=dep[f]+1;
int maxson=-1;
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==f) continue;
dfs1(y,x);
siz[x]+=siz[y];
if(maxson<siz[y]){
maxson=siz[y];
son[x]=y;
}
}
}
void dfs2(int x,int topf){
id[x]=++tim;
top[x]=topf;
if(son[x]) dfs2(son[x],topf);
for(int i=head[x],y;i;i=e[i].next){
y=e[i].to;
if(y==fa[x]||y==son[x]) continue;
dfs2(y,y);
}
}
int LCA(int x,int y){
while(top[x]!=top[y]){
if(dep[top[x]]<dep[top[y]]) swap(x,y);
x=fa[top[x]];
}
if(dep[x]>dep[y]) swap(x,y);
return x;
}
bool cmp(int a,int b){
return id[a]<id[b];
}
int dfs(int x){
int ans=0,sum=0;
for(int i=fir[x];i;i=nxt[i])
ans+=dfs(to[i]),sum+=dp[to[i]];
if(vis[x]) dp[x]=1,ans+=sum;
else if(sum>1) dp[x]=0,ans++;
else dp[x]=sum;
fir[x]=0;
return ans;
}
int main()
{
n=read();
int x,y,w,k,lca,flag;
for(int i=1;i<n;i++){
x=read(),y=read();
add(x,y);add(y,x);
}
dfs1(1,0);dfs2(1,1);
m=read();
for(int t=1;t<=m;t++){
k=read();
for(int i=1;i<=k;i++){
h[i]=read();
vis[h[i]]=1;
}
flag=0;
for(int i=1;i<=k;i++)
flag|=vis[fa[h[i]]];
if(flag){
printf("-1\n");
for(int i=1;i<=k;i++) vis[h[i]]=0;
continue;
}
sort(h+1,h+k+1,cmp);
cnt=0;sta[Top=1]=1;
for(int i=1;i<=k;i++){
lca=LCA(sta[Top],h[i]);
while(dep[lca]<dep[sta[Top]]){
if(dep[sta[Top-1]]<=dep[lca]){
addedge(lca,sta[Top--]);
if(sta[Top]!=lca) sta[++Top]=lca;
break;
}
addedge(sta[Top-1],sta[Top]);
Top--;
}
if(sta[Top]!=h[i]) sta[++Top]=h[i];
}
while(--Top) addedge(sta[Top],sta[Top+1]);
printf("%d\n",dfs(1));
for(int i=1;i<=k;i++) vis[h[i]]=0;
}
return 0;
}