1. 程式人生 > 實用技巧 >雨天的尾巴「線段樹合併+樹上差分」

雨天的尾巴「線段樹合併+樹上差分」

雨天的尾巴「線段樹合併+樹上差分」

題目描述(簡化版)

\(N\) 個點,形成一個樹狀結構。有 %M$ 次發放,每次選擇兩個點 \(x,y\) 對於 \(x\)\(y\) 的路徑上(含 \(x,y\))每個點發一袋 \(Z\) 型別的物品。完成所有發放後,每個點存放最多的是哪種物品。

輸入格式

第一行數字 \(N,M\)
接下來 \(N-1\) 行,每行兩個數字 \(a,b\),表示 \(a\)\(b\) 間有一條邊
再接下來 \(M\) 行,每行三個數字 \(x,y,z\).如題

輸出格式

輸出有 \(N\)
\(i\) 行的數字表示第 \(i\) 個點存放最多的物品是哪一種,如果有多種物品的數量一樣,輸出編號最小的。如果某個點沒有物品則輸出 \(0\)

思路分析

板子題

  • 首先發現這是在樹上進行區間修改,所以可以用樹上差分來處理
  • 每一個節點都要記錄不同種類的物品的個數,因此每一個節點都開一個權值線段樹,以物品種類編號為下標
  • 因為使用了樹上差分,所以每一個節點的線段樹都要與其子樹內的所有節點的線段樹合併,類似字首和思想,這樣就可以得到該節點最終的線段樹,就可以更新答案了

詳見程式碼

Code

#include<cstdio>
#include<cstring>
#include<cmath>
#include<algorithm>
#define N 100005
#define M 6000005
#define R register
using namespace std;
inline int read(){
	int x = 0,f = 1;
	char ch = getchar();
	while(ch>'9'||ch<'0'){if(ch=='-')f=-1;ch=getchar();}
	while(ch>='0'&&ch<='9'){x=(x<<1)+(x<<3)+(ch^48);ch=getchar();}
	return x*f;
}
int n,m,ed,head[N],siz[N],dep[N],f[N],son[N],top[N];
int tr[M],mx[M],ls[M],rs[M];//tr記錄的是線段樹中每個物品的種類數,mx記錄的是種類數最多的物品
int ans[N],cnt,x[N],y[N],z[N],rt[N]; //rt表示每個結點的線段樹的根節點
struct edge{
	int to,next;
}e[N<<1];
int len;
void addedge(int u,int v){
	e[++len].to = v;
	e[len].next = head[u];
	head[u] = len;
}
void dfs1(int u,int fa){ //樹剖求LCA
	siz[u] =  1;
	dep[u] = dep[fa]+1;
	f[u] = fa;
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(v==fa)continue;
		dfs1(v,u);
		siz[u] += siz[v];
		if(siz[v]>siz[son[u]])son[u] = v;
	}
}
void dfs2(int u,int tp){
	top[u] = tp;
	if(son[u])dfs2(son[u],tp);
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(v!=son[u]&&v!=f[u])dfs2(v,v);
	}
}
int LCA(int x,int y){
	while(top[x]!=top[y]){
		if(dep[top[x]]<dep[top[y]])swap(x,y);
		x = f[top[x]];
	}
	return dep[x] < dep[y] ? x : y;
}//————————————手動分割線——————————————

//以下為線段樹操作
void pushup(int rt){
    if(tr[ls[rt]]>=tr[rs[rt]])tr[rt] = tr[ls[rt]],mx[rt] = mx[ls[rt]];
    else tr[rt] = tr[rs[rt]],mx[rt] = mx[rs[rt]];
}
int modify(int rt,int l,int r,int pos,int val){
	if(!rt) rt = ++cnt; //動態開點
	if(l==r){
		tr[rt] += val;
		mx[rt] = l;
		return rt;
	}
	int mid = (l+r)>>1;
	if(pos<=mid)ls[rt] = modify(ls[rt],l,mid,pos,val);
	else rs[rt] = modify(rs[rt],mid+1,r,pos,val);
	pushup(rt);
	return rt;
}
int merge(int a,int b,int l,int r){ //線段樹合併
	if(!a)return b;
	if(!b)return a;
	if(l==r){
		tr[a] += tr[b];
		mx[a] = l;
		return a;
	}
	int mid = (l+r)>>1;
	ls[a] = merge(ls[a],ls[b],l,mid),rs[a] = merge(rs[a],rs[b],mid+1,r);
	pushup(a);
	return a;
}
void get_ans(int u){//dfs統計差分答案
	for(R int i = head[u];i;i = e[i].next){
		int v = e[i].to;
		if(dep[v]>dep[u]){
			get_ans(v);
			rt[u] = merge(rt[u],rt[v],1,ed);
		}
	}
	if(tr[rt[u]])ans[u] = mx[rt[u]];
}
int main(){
	n = read(),m = read();
	for(R int i = 1;i < n;i++){
		int a = read(),b = read();
		addedge(a,b),addedge(b,a);
	}
	dfs1(1,0),dfs2(1,1);//注意是dfs1(1,0)而不是dfs1(1,1),因為這個坑花了多長時間就不提了:-)
	for(R int i = 1;i <= m;i++){
		x[i] = read(),y[i] = read(),z[i] = read();
		ed = max(ed,z[i]);
	}
	for(R int i=1;i<=m;i++){//差分處理
		int lca=LCA(x[i],y[i]);
		rt[x[i]]=modify(rt[x[i]],1,ed,z[i],1),rt[y[i]]=modify(rt[y[i]],1,ed,z[i],1);
		rt[lca]=modify(rt[lca],1,ed,z[i],-1);
		if(f[lca]) rt[f[lca]]=modify(rt[f[lca]],1,ed,z[i],-1);
        }
	get_ans(1);
	for(R int i = 1;i <= n;i++)printf("%d\n",ans[i]);
	return 0;
}