1. 程式人生 > 其它 >虛樹 學習筆記

虛樹 學習筆記

聽起來很高階的資料結構,但實際上很好理解。

主要用於處理標記 \(k\) 個特殊點,進行一些詢問的問題,一般題中會給出 \(\sum k\) 的資料範圍。


1. Idea

虛樹其實就是將一棵樹的所有特殊點以及特殊點兩兩之間的 \(\text{lca}\) 構成一棵新的樹

具體來說,我們可以預處理出樹上所有節點的 \(\text{dfs}\) 序,然後將特殊點按照 \(\text{dfs}\) 序排序,將排序後的每兩個相鄰的特殊點之間求一個 \(\text{lca}\),並將其加入特殊點的行列。

為什麼 \(\text{dfs}\) 序相鄰的特殊點的 \(\text{lca}\) 就可以涵蓋所有特殊點兩兩之間的 \(\text{lca}\)

了呢?

考慮對於三個 \(\text{dfs}\) 序分別為 \(a,b,c\) 的節點,其中 \(a<b<c\) 。若 \(b\)\(a\) 的子樹內,則 \(b\)\(c\)\(\text{lca}\) 顯然也是 \(a\)\(c\)\(\text{lca}\);若 \(b\)\(a\) 的子樹外,不妨設 \(c\)\(b\)\(\text{lca}\) 深度小於 \(a\)\(b\)\(\text{lca}\) ,則前者必定也是後者的祖先(畫個圖很明顯),也是 \(a\) 的祖先,若不是 \(a\)\(c\)

\(\text{lca}\),則一定也不是 \(b\)\(c\)\(\text{lca}\),故這兩個 \(\text{lca}\) 一定也包含 \(a\)\(c\)\(\text{lca}\)。故推廣後可知,必定涵蓋。(怎麼感覺像繞口令)

綜上所述,記住就可以啦!(大霧)

處理完之後,我們再將記錄特殊點的陣列排序一次,去重後考慮建樹。

由於特殊點是按照 \(\text{dfs}\) 序排序的,所以其實儲存特殊點的陣列也是按照新樹的 \(\text{dfs}\) 序排序的。那麼建樹的問題其實就轉化為,給定一棵樹的 \(\text{dfs}\) 序,建出這棵樹。這個問題的解決方式自然很簡單,用一個類似於單調棧的棧維護當前節點的祖先,每次彈出棧頂不符合條件的祖先後棧頂的祖先即為該節點在新樹中的父親,然後將該節點加入棧頂即可。

建好樹之後,我們就可以在樹上進行 dp 等操作處理答案了。

特殊點一共 \(k\) 個,擴充套件完 \(\text{lca}\) 後最多 \(2k-1\) 個,故 dp 等操作的複雜度可以被成功地降到 \(O(\sum k)\) 相關,而由於建樹時需要獲得 \(k-1\)\(\text{lca}\),故所有詢問的建樹總複雜度為 \(O(\sum k\times\log n)\)。且為線上演算法

需要注意新樹中的根節點並不一定為原樹中的根節點,一般情況下,新樹中的根節點是所有特殊點當中 \(\text{dfs}\) 序最小的節點。


2. Example

2.1 CF613D Kingdom and its Cities

給定一棵樹,每次詢問給定 \(k\) 個特殊點,找出儘量少的非特殊點使得刪去這些點後特殊點兩兩不連通。\(\sum k\le n.\)

觀察到 選點+限制所有詢問總點數,我們基本就可以確定這是一道虛樹的題目。

先建樹,如果兩個特殊點的距離為 \(1\),則無解。

否則對於一個節點,如果是特殊點,則需要斷開其所有需要斷開的兒子,並向上傳遞一個需要被斷開的標記;如果不是特殊點,且兒子中需要被斷開的標記個數大於等於 \(2\),則斷開當前節點,如果個數為 \(1\),則向上傳遞需要被斷開的標記,個數為 \(0\) 則不傳。最後的答案即為過程中一共被斷開的點數。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 100005
int n,m,k,jtot,dtot,ltot,st[N<<1],stot,ans;
bool can[N],flag;
struct node{
	int to,next;
	node (int to=0,int next=0)
		:to(to),next(next){}
};
struct node2{
	int head[N],tot;
	node e[N<<1];
	void adde(int u,int v){
		e[++tot]=node(v,head[u]);
		head[u]=tot;
	}
}S,T;
struct node1{
	int fa,tp,zson,size,dep;
}e[N];
struct node3{
	int dfn,low,id;
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 
void dfs1(int u){
	p[u].dfn=++dtot,e[u].size=1;
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=e[u].fa){
			e[v].fa=u;
			e[v].dep=e[u].dep+1;
			dfs1(v);
			e[u].size+=e[v].size;
			if (e[v].size>e[e[u].zson].size) e[u].zson=v;
		}
	}
	p[u].low=++ltot,p[u].id=u;
}
void dfs2(int u,int tp){
	e[u].tp=tp;
	if (e[u].zson) dfs2(e[u].zson,tp);
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=e[u].fa&&v!=e[u].zson){
			dfs2(v,v);
		}
	}
}
bool cmp(node3 x,node3 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	while (e[x].tp!=e[y].tp){
		int xp=e[x].tp,yp=e[y].tp;
		if (e[xp].dep>e[yp].dep) x=e[xp].fa;
		else y=e[yp].fa;
	}
	return e[x].dep<e[y].dep?x:y;
}
void build(){
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++){
		jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	}
	sort(jl+1,jl+1+jtot,cmp);
	stot=0;
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("--------");
	st[++stot]=jl[1].id;
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
		while (stot&&(p[st[stot]].dfn>jl[i].dfn||p[st[stot]].low<jl[i].low)) --stot;
//		printf("%d %d\n",jl[i].id,st[stot]);
		if (st[stot]==e[jl[i].id].fa&&can[st[stot]]&&can[jl[i].id]){
			flag=1;
			puts("-1");
			break;
		}
		T.adde(st[stot],jl[i].id);
		st[++stot]=jl[i].id;
	}
}
void to_pre(){
	for (int i=1;i<=jtot;i++) can[jl[i].id]=0,T.head[jl[i].id]=0;
	T.tot=0;
}
int dp(int u){
//	printf("::%d\n",u);
	int sum=0;
	for (int i=T.head[u];i;i=T.e[i].next){
		int v=T.e[i].to;
		sum+=dp(v);
	}
	if (can[u]){
		ans+=sum;
		return 1;
	}else if (sum==0) return 0;
	else if (sum==1) return 1;
	ans++;return 0;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read();
		S.adde(u,v),S.adde(v,u);
	}
	dfs1(1),dfs2(1,1);
	m=read();
	while (m--){
		k=read();
		jtot=0;
		for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[jtot].id]=1;
		build();
		if (!flag) ans=0,dp(st[1]),printf("%d\n",ans);
		else flag=0;
		to_pre();
	}
	return 0;
}

2.2 P2495 [SDOI2011]消耗戰

給定一棵樹,每次詢問給定 \(k\) 個特殊點,需要斷掉一些邊使得從根節點無法到達任何特殊點,求最小需要斷掉的邊數。\(\sum k\le2n\).

同樣,觀察到給定特殊點以及關於所有詢問特殊點總個數的限制,考慮使用虛樹。

很容易想到對於一個虛樹上的點,如果是特殊點,那麼其實斷掉其子樹上的邊已經沒有用了,只能斷掉該節點到根路徑上的邊;如果不是特殊點,則將所有子樹返回上來的答案加起來之後,與到其父親的邊權取 \(\min\) 後上傳給父親(即要不斷掉所有有特殊點的子樹,要不斷掉與根路徑上的邊)。最終 \(1\) 節點的答案即為最終答案。

注意由於敵軍島嶼在 \(1\) 號,所以如果建立虛樹後虛樹的根不是 \(1\) 號節點的話,還需要將 \(1\) 號節點再向虛樹的根節點連一條邊,並將 \(1\) 號節點作為根節點(注意,多測清空的時候 \(1\) 號節點相關的資訊也要清空!)。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 500005
int n,fa[N][21],dep[N],m,k,jtot,dtot,ltot,st[N],stot;
ll minn[N][21];
bool can[N];
struct node{
	int to,next;
	ll w;
	node (int to=0,int next=0,ll w=0)
		:to(to),next(next),w(w){}
};
struct node1{
	int head[N],tot;
	node e[N<<1];
	void adde(int u,int v,int w){
		e[++tot]=node(v,head[u],w);
		head[u]=tot;
	}
}S,T;
struct node2{
	int dfn,low,id;	
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 
void dfs(int u){
	p[u].dfn=++dtot,p[u].id=u;
	for (int i=S.head[u];i;i=S.e[i].next){
		int v=S.e[i].to;
		if (v!=fa[u][0]){
			fa[v][0]=u;
			dep[v]=dep[u]+1;
			minn[v][0]=S.e[i].w;
			for (int i=1;(1<<i)<=dep[v];i++) fa[v][i]=fa[fa[v][i-1]][i-1],minn[v][i]=min(minn[v][i-1],minn[fa[v][i-1]][i-1]);
			dfs(v);
		}
	}
	p[u].low=++ltot;
}
bool cmp(node2 x,node2 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=20;i>=0;i--)
		if (dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=20;i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
ll query(int x,int y){
	ll sum=1e15;
	for (int i=20;i>=0;i--)
		if (dep[y]-(1<<i)>=dep[x]) sum=min(sum,minn[y][i]),y=fa[y][i];
	return sum;
}
void build(){
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++) jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	sort(jl+1,jl+1+jtot,cmp);
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("--------");
	st[stot=1]=jl[1].id;
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
		while (stot&&(p[st[stot]].dfn>jl[i].dfn||p[st[stot]].low<jl[i].low)) --stot;
		T.adde(st[stot],jl[i].id,query(st[stot],jl[i].id));
		st[++stot]=jl[i].id;
	}
	if (st[1]!=1) T.adde(1,st[1],query(1,st[1]));
}
void to_pre(){
	can[1]=T.head[1]=0;
	for (int i=1;i<=jtot;i++) can[jl[i].id]=0,T.head[jl[i].id]=0;
	T.tot=0;
}
ll dp(int u){
	ll sum=0;
	for (int i=T.head[u];i;i=T.e[i].next){
		int v=T.e[i].to;
		ll w=min(dp(v),T.e[i].w);
//		printf("%d %d %d\n",u,v,w);
		sum+=w;
	}
	if (can[u]) return 1e12;
	return sum;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read(),w=read();
		S.adde(u,v,w),S.adde(v,u,w);
	}
	dfs(1);
	m=read();
	while (m--){
		k=read();
		jtot=0;
		for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[i].id]=1;
		build();
		printf("%lld\n",dp(1));
		to_pre();
	}
	return 0;
}

2.3 P4103 [HEOI2014]大工程

給定一棵樹,每次詢問給定 \(k\) 個特殊點,求它們兩兩之間距離的距離和,最小距離和最大距離。\(\sum k\le2n\).

有了前兩題的經驗,很自然地就想到了虛樹。

先想想普通方法怎麼做。

分別考慮三個問題。距離和其實最好求,列舉每一條邊,該邊的貢獻其實就是其左側的特殊點個數和其右邊的特殊點個數的乘積再乘上該邊的邊權,對所有邊的貢獻求和即為答案。

最小距離和最大距離方法類似,考慮樹形 dp,對於每個節點分別其子樹內過該節點的答案。對於最小距離,遍歷到一個節點時,如果該節點為特殊點,則該節點子樹內的最小距離顯然是其子樹內特殊點到該節點的最小距離;如果不是,最小距離則為其子樹內特殊點到該節點的最小距離和次小距離之和。對於最大距離,則最大距離為子樹內最大距離與次大距離之和,如果沒有次大距離且當前節點為特殊點時,次大距離可以用 \(0\) 代替。而這些東西都可以通過簡單的遞迴與返回處理。

那麼其實虛樹上的求解和普通樹上的求解基本相同,只不過需要將虛樹上的鏈權賦值為該鏈上邊權的最小值,可以通過倍增或樹剖輕鬆處理。

#include<bits/stdc++.h>
using namespace std;
typedef long long ll;
#define N 1000005
int n,dtot,ltot,jtot,fa[N][22],dep[N],st[N<<1],stot,ans2,ans3,size[N],k;
ll ans1;
bool can[N];
struct node{
	int head[N],tot;
	int to[N<<1],next[N<<1],w[N<<1],from[N<<1];
	void adde(int u,int v,int ww){
		++tot;
		from[tot]=u,to[tot]=v,next[tot]=head[u],w[tot]=ww;
		head[u]=tot;
	}
}S,T;
struct node1{
	int id,dfn,low;
}p[N],jl[N<<1];
int read(){
	int wh=0,fh=1;
	char c=getchar();
	while (c>'9'||c<'0'){
		if (c=='-') fh=-1;
		c=getchar();
	}
	while (c>='0'&&c<='9'){
		wh=(wh<<3)+(wh<<1)+(c^48);
		c=getchar();
	}
	return wh*fh;
} 

void dfs(int u){
	p[u].dfn=++dtot;
	for (int i=S.head[u];i;i=S.next[i]){
		int v=S.to[i];
		if (v!=fa[u][0]){
			dep[v]=dep[u]+1;
			fa[v][0]=u;
			for (int i=1;(1<<i)<=dep[v];i++) fa[v][i]=fa[fa[v][i-1]][i-1];
			dfs(v);
		}
	}
	p[u].low=++ltot;
	p[u].id=u;
}
bool cmp(node1 x,node1 y){
	return x.dfn<y.dfn;
}
int getlca(int x,int y){
	if (dep[x]<dep[y]) swap(x,y);
	for (int i=21;i>=0;i--)
		if (dep[x]-(1<<i)>=dep[y]) x=fa[x][i];
	if (x==y) return x;
	for (int i=21;i>=0;i--)
		if (fa[x][i]!=fa[y][i]) x=fa[x][i],y=fa[y][i];
	return fa[x][0];
}
void check(int u){
	printf("-----%d\n",u);
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		check(v);
	}
}
void build(){
	k=read();
	jtot=0;
	for (int i=1;i<=k;i++) jl[++jtot]=p[read()],can[jl[i].id]=1;
	sort(jl+1,jl+1+jtot,cmp);
	for (int i=1;i<k;i++)
		jl[++jtot]=p[getlca(jl[i].id,jl[i+1].id)];
	sort(jl+1,jl+1+jtot,cmp);
	stot=0;
	st[++stot]=jl[1].id;
//	puts("------");
//	for (int i=1;i<=jtot;i++) printf("%d ",jl[i].id);
//	puts("\n------");
	for (int i=2;i<=jtot;i++){
		if (jl[i].dfn==jl[i-1].dfn) continue;
	//	printf("%d %d %d %d %d %d\n",st[stot],jl[i].id,p[st[stot]].dfn,jl[i].dfn,p[st[stot]].low,jl[i].low);
		while (stot&&(p[st[stot]].low<jl[i].low||p[st[stot]].dfn>jl[i].dfn)) --stot;
		T.adde(st[stot],jl[i].id,dep[jl[i].id]-dep[st[stot]]);
	//	printf("%d %d %d\n",st[stot],jl[i].id,dep[jl[i].id]-dep[st[stot]]);
		st[++stot]=jl[i].id;
	}
//	check(st[1]);
}
void dfs1(int u){
	if (can[u]) size[u]=1;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		dfs1(v);
		size[u]+=size[v];	
	}
}
void dfs2(int u){
	size[u]=0;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i];
		dfs2(v);
	}
}
int query_sum(){
	dfs1(st[1]);
	for (int i=1;i<=T.tot;i++){
		int u=T.from[i],v=T.to[i];
		ans1+=(ll)size[v]*(ll)(k-size[v])*(ll)T.w[i];
	}
	dfs2(st[1]);
}
int query_max(int u){
	int maxn=0,cmax=0;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i],w=T.w[i];
		int now=query_max(v)+w;
		if (now>maxn) cmax=maxn,maxn=now;
		else if (now>cmax) cmax=now;
	}
	if (cmax||maxn&&can[u]) ans3=max(ans3,maxn+cmax);
	return maxn;
}
int query_min(int u){
	int minn=1e9,cmin=1e9;
	for (int i=T.head[u];i;i=T.next[i]){
		int v=T.to[i],w=T.w[i];
		int now=query_min(v)+w;
		if (now<minn) cmin=minn,minn=now;
		else if (now<cmin) cmin=now;
	}
	if (can[u]){
		if (minn) ans2=min(ans2,minn);
		return 0;
	}
	if (cmin)
		ans2=min(ans2,minn+cmin);
	return minn;
}
int main(){
	n=read();
	for (int i=1;i<n;i++){
		int u=read(),v=read();
		S.adde(u,v,1),S.adde(v,u,1);
	}
	dfs(1);
	int q=read();
	while (q--){
		build();
		ans1=ans3=0;ans2=1e9;
		query_sum(),query_min(st[1]),query_max(st[1]);
		printf("%lld %d %d\n",ans1,ans2,ans3);
		for (int i=1;i<=jtot;i++) T.head[jl[i].id]=0,can[jl[i].id]=0;
		T.tot=0;
	}
	return 0;
}