1. 程式人生 > 實用技巧 >樹鏈剖分筆記 輕重鏈剖分

樹鏈剖分筆記 輕重鏈剖分

前置知識

  • 最近公共祖先(LCA),樹形DP,DFS序,鏈式前向星存圖,線段樹

功能

對一棵樹進行剖分,將其分成幾條鏈,將樹形變為線性,減少處理的難度
需要處理的問題有

  • 將樹從\(x\)\(y\)結點最短路徑上所有節點的值都加上\(z\)
  • 求樹從\(x\)\(y\)結點最短路徑上所有節點的值之和
  • 將以\(x\)為根節點的子樹內所有節點值都加上\(z\)
  • 求以\(x\)為根節點的子樹內所有節點值之和

定義及概念

  • 重兒子:對於每一個非葉子節點,他的兒子中以那個兒子為根的子樹的節點數最大的兒子,為該點的重兒子
  • 輕兒子:對於每個非葉子節點,他的兒子中不是重兒子的剩下的所有兒子就是輕兒子
  • 葉子節點沒有重兒子也沒有輕兒子,因為他根本沒有兒子(QAQ)
  • 重邊:一個父親連線他的重兒子的邊稱為重邊
  • 輕邊:重邊剩下的就是輕邊
  • 重鏈:相鄰的重邊連起來的連線一條重兒子的鏈稱為重鏈
    • 對於葉子節點,如果他是輕兒子,那麼有一條以他自己為起點的長度為1的鏈
    • 每一條重鏈以其輕兒子為起點

步驟

\(DFS1\)

功能
  • 標記每個點的深度
  • 標記每個點的父親
  • 標記每個非葉子的子樹大小(包括他自己)
  • 標記每個非葉子節點的重兒子的編號
程式碼實現
void dfs1(int x,int f,int deep)//x當前節點,f父親,deep深度
{
	dep[x]=deep;//標記深度 
	fa[x]=f;//標記每個點的父親 
	siz[x]=1;//標記每個非葉子節點的子樹的大小 
	int maxson=-1;//記錄重兒子的兒子數量
	for(int i=head[x];i;i=e[i].last)
	{
		int y=e[i].to;
		if(y==f) continue;//如果是父親那麼就繼續去找下一個
		dfs1(y,x,deep+1);
		siz[x]+=siz[y];//加上子樹的節點數量 
		if(siz[y]>maxson)
		{
			son[x]=y;
			maxson=siz[y];//如果該子節點更大,那麼就標記他的每個非葉子節點的
			//重兒子的編號 
		}
	} 
}

\(DFS2\)

功能
  • 標記每一個點的新編號(線上段樹裡面的)
  • 給每個點的新編號賦上這個點的初始值
  • 處理好每個點所在的鏈的頂端
  • 處理每條鏈
程式碼實現(先處理重兒子,再處理輕兒子)
void dfs2(int x,int topf)//x為當前的節點,topf為當前鏈上最頂端的節點
{
	id[x]=++cnt;//標記每個點的新編號 
	wt[cnt]=w[x];//把每個點的初始值都賦到新的編號上來
	top[x]=topf;//標記這個點所在的鏈的頂端
	if(!son[x]) return;//如果沒有重兒子(兒子),那麼就返回
	dfs2(son[x],topf);//先處理重兒子,在處理輕兒子,遞迴處理
	for(int i=head[x];i;i=e[i].last)
	{
		int y=e[i].to;
		if(y==fa[x]||y==son[x])continue;
		//如果遍歷到父親結點或者是重兒子,那麼就繼續搜尋
		dfs2(y,y);
		//每一個輕兒子都有一條從自己開始的鏈
	} 
} 

處理問題

  • 先標上新的編號


因為順序是按照先重兒子,再輕兒子來處理的,所以每一條重鏈的新編號是連續的
因為是用的\(DFS\)所以每一個子樹的新編號也是連續的

  • 首先,當我們要處理任意兩點間的路徑時:
    設我們所在練的頂端深度更深的那個點為\(x\)
    • \(ans\)加上\(x\)點到\(x\)所在鏈的頂端這一段區間的點權和
    • \(x\)跳到\(x\)所在的鏈頂端的那個點的上面的那個點
    • 不斷地執行這兩個操作,知道這兩個點處在一條鏈上面,然後此時在加上這兩個點之間的區間和

在這個時候我們注意到,我們所要處理的所有的區間都是連續的編號(新編號),那麼我們可以用線段樹處理連續編號區間和,每次查詢時間複雜度為\(O(log^2n)\)

int qRange(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])//把x改在所在鏈更深的點
		swap(x,y);
		res=0;
		query(1,1,n,id[top[x]],id[x]);//ans加上x點到所在鏈頂端的這一區間的點權和
		ans+=res;
		ans%=mod;
		x=fa[top[x]]; 
	}
	if(dep[x]>dep[y])swap(x,y);
	res=0;
	query(1,1,n,id[x],id[y]);
	ans+=res;
	return ans%mod; 
 } 

處理一點及其子樹的點權和

記錄每一個非葉子節點的子樹的大小,並且每一個子樹的新編號都是連續的,於是就直接線段樹區間查詢即可時間複雜度為\(O(log n)\)

int qson(int x)
{
	res=0;
	query(1,1,n,id[x],id[x]+siz[x]-1);//子樹的右端點為id[x]+siz[x]-1,可以手推一下
	return res;
}

區間修改

void updson(int x,int k)
{
	update(1,1,n,id[x],id[x]+siz[x]-1,k);
}
void updRange(int x,int y,int k)//區間修改
{
	k%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])//讓x的深度更深 
		swap(x,y);
		update(1,1,n,id[top[x]],id[x],k);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	update(1,1,n,id[x],id[y],k);
}


完整程式碼(200行高能)

#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<stack>
#include<map>
#include<cmath>
#include<algorithm>
using namespace std;

#define mid ((l+r)>>1)
#define lson rt<<1,l,mid
#define rson rt<<1|1,mid+1,r
#define len (r-l+1)

const int N=2e5+10;
struct node{
	int last;
	int to;
}e[N];
int head[N]; 

int n,m,r,mod;
int e_cnt,w[N],wt[N];
int a[N<<2],laz[N<<2];
//線段樹陣列,懶惰標記
int son[N],id[N],fa[N],cnt,dep[N],siz[N],top[N];
//重兒子編號,新編號,父親結點,dfs序,深度,子樹的大小,當前鏈的頂端結點
int res=0;
void add(int from,int to)
{
	e[++e_cnt].last=head[from];
	e[e_cnt].to=to;
	head[from]=e_cnt;
}
//------------------------------------------------線段樹
void pushdown(int rt,int lenn)
{
	laz[rt<<1]+=laz[rt];
	laz[rt<<1|1]+=laz[rt];
	a[rt<<1]+=laz[rt]*(lenn-(lenn>>1));
	a[rt<<1|1]+=laz[rt]*(lenn>>1);
	a[rt<<1]%=mod;
	a[rt<<1|1]%=mod;
	laz[rt]=0; 
}
void build(int rt,int l,int r)
{
	if(l==r)
	{
		a[rt]=wt[l];//賦值點權值
		if(a[rt]>mod) a[rt]%=mod;
		return; 
	}
	build(lson);
	build(rson);
	a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
}
void query(int rt,int l,int r,int L,int R)
{
	if(L<=l&&r<=R)
	{
		res+=a[rt];
		res%=mod;
		return;
	}
	else 
	{
		if(laz[rt]) pushdown(rt,len);
		if(L<=mid) query(lson,L,R);
		if(R>mid) query(rson,L,R);
	}
}
void update(int rt,int l,int r,int L,int R,int k)
//當前節點,當前區間的左,右,要修改的區間左,右,修改值 
{
	if(L<=l&&r<=R)
	{
		laz[rt]+=k;
		a[rt]+=k*len;
	}
	else
	{
		if(laz[rt]) pushdown(rt,len);
		if(L<=mid) update(lson,L,R,k);
		if(R>mid) update(rson,L,R,k);
		a[rt]=(a[rt<<1]+a[rt<<1|1])%mod;
	}
} 
//------------------------------------------------樹鏈剖分
int qRange(int x,int y)
{
	int ans=0;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])//把x改在所在鏈更深的點
		swap(x,y);
		res=0;
		query(1,1,n,id[top[x]],id[x]);//ans加上x點到所在鏈頂端的這一區間的點權和
		ans+=res;
		ans%=mod;
		x=fa[top[x]]; 
	}
	if(dep[x]>dep[y])swap(x,y);
	res=0;
	query(1,1,n,id[x],id[y]);
	ans+=res;
	return ans%mod; 
 } 
void updRange(int x,int y,int k)//區間修改
{
	k%=mod;
	while(top[x]!=top[y])
	{
		if(dep[top[x]]<dep[top[y]])//讓x的深度更深 
		swap(x,y);
		update(1,1,n,id[top[x]],id[x],k);
		x=fa[top[x]];
	}
	if(dep[x]>dep[y]) swap(x,y);
	update(1,1,n,id[x],id[y],k);
}
int qson(int x)
{
	res=0;
	query(1,1,n,id[x],id[x]+siz[x]-1);
	return res;
}
void updson(int x,int k)
{
	update(1,1,n,id[x],id[x]+siz[x]-1,k);
}
void dfs1(int x,int f,int deep)//x當前節點,f父親,deep深度
{
	dep[x]=deep;//標記深度 
	fa[x]=f;//標記每個點的父親 
	siz[x]=1;//標記每個非葉子節點的子樹的大小 
	int maxson=-1;//記錄重兒子的兒子數量
	for(int i=head[x];i;i=e[i].last)
	{
		int y=e[i].to;
		if(y==f) continue;//如果是父親那麼就繼續去找下一個
		dfs1(y,x,deep+1);
		siz[x]+=siz[y];//加上子樹的節點數量 
		if(siz[y]>maxson)
		{
			son[x]=y;
			maxson=siz[y];//如果該子節點更大,那麼就標記他的每個非葉子節點的
			//重兒子的編號 
		}
	} 
}
void dfs2(int x,int topf)//x為當前的節點,topf為當前鏈上最頂端的節點
{
	id[x]=++cnt;//標記每個點的新編號 
	wt[cnt]=w[x];//把每個點的初始值都賦到新的編號上來
	top[x]=topf;//標記這個點所在的鏈的頂端
	if(!son[x]) return;//如果沒有重兒子(兒子),那麼就返回
	dfs2(son[x],topf);//先處理重兒子,在處理輕兒子,遞迴處理
	for(int i=head[x];i;i=e[i].last)
	{
		int y=e[i].to;
		if(y==fa[x]||y==son[x])continue;
		//如果遍歷到父親結點或者是重兒子,那麼就繼續搜尋
		dfs2(y,y);
		//每一個輕兒子都有一條從自己開始的鏈
	} 
} 
int main()
{
	cin>>n>>m>>r>>mod;
	for(int i=1;i<=n;i++)
	cin>>w[i];//節點的初始權值 
	for(int i=1;i<n;i++)
	{
		int x,y;
		cin>>x>>y;
		add(x,y);
		add(y,x); 
	}
	dfs1(r,0,1);
	dfs2(r,r);
	build(1,1,n);
	while(m--)
	{
		int k,x,y,z;
		cin>>k;
		if(k==1)
		{
			cin>>x>>y>>z;
			updRange(x,y,z);
		}
		else if(k==2)
		{
			cin>>x>>y;
			cout<<qRange(x,y)<<endl;
		}
		else if(k==3)
		{
			cin>>x>>y;
			updson(x,y);
		}
		else 
		{
			cin>>x;
			cout<<qson(x)<<endl;
		}
	}
	return 0;
 }