1. 程式人生 > 實用技巧 >線段樹合併 從入門到入土

線段樹合併 從入門到入土

前置知識:

  • 動態開點線段樹
  • 權值線段樹

如果你上面那兩個不會的話,出門右轉模板區。

線段樹合併是什麼東東呢?

他其實就是把好幾個零散的線段樹合併在一起。

就相當於重新開一顆權值線段樹儲存原來兩棵線段樹的資訊。

他一般可以用來解決一些平衡樹能做的題比如第\(k\) 大,排名,找前驅後繼。

大體的實現思路:

  • \(x,y\) 一個為空節點,則以非空的最為合併之後的節點。

  • \(x,y\) 都不為空,則遞迴合併左右子樹,以 \(x\) 作為合併之後的節點,並自下而上合併子樹的資訊

思想理解了的話,程式碼就不難實現了。

在這裡 \(genshy\) 給大家提供兩種不同的寫法(自認為比較方便)。

1.合併區間最大值

這個時候我們還需要添兩個引數記錄一下區間的左右端點。畢竟葉子節點和非葉子節點的合併方式是不太一樣的。

葉子節點的話直接把 \(x,y\) 兩個值相加就可以,非葉子節點的話就可以由下面子樹 \(up\) 上來

Code

void merage(int &x,int y,int l,int r)
{
	if(!x) {x = y; return;}//非空節點
	if(!y) return;
	int mid = (l + r)>>1;
	if(l == r) //葉子節點直接把權值相加
	{
		tr[x].sum += tr[y].sum;
		return;
	}
	merage(tr[x].lc,tr[y].lc,l,mid);//遞迴合併左右子樹
	merage(tr[x].rc,tr[y].rc,mid+1,r);
	up(x);//up一下
}

2.合併區間和

這種型別我們就可以少傳記錄區間端點的兩個引數,直接把 \(x,y\) 兩個節點的值相加就完事了。

Code

void merage(int &x,int y)
{
 if(!x) {x = y; return;}
 if(!y) return;
 tr[x].sum += tr[y].sum;
 merage(tr[x].lc,tr[y].lc);
 merage(tr[x].rc,tr[y].rc);
}

一張很透徹的影象:

複雜度證明:

具體的我不太會證,所以直接把日報上的搬過來了。

先來思考一下在動態開點線段樹中插入一個點會加入多少個新的節點
線段樹從頂端到任意一個葉子結點之間有 \(logn\)

層,每層最多新增一個節點
所以插入一個新的點複雜度是 \(logn\)

兩棵線段樹合併的複雜度顯然取決於兩棵線段樹重合的葉子節點個數,假設有 \(m\) 個重合的點,這兩棵線段樹合併的複雜度就是 \(mlogn\) 了,所以說,如果要合併兩棵滿滿的線段樹,這個複雜度絕對是遠大於 \(logn\) 級別的。
也就是說,千萬不要以為線段樹合併對於任何情況都是 \(logn\) 的!

那麼為什麼資料範圍 \(10^5\) 的題目線段樹合併還穩得一批?
這是因為 \(logn\) 的複雜度僅適用於插入點少的情況。
如果 \(n\) 與加入的總點數規模基本相同,我們就可以把它理解成每次操作 \(O(logn)\)

來證明一下:
假設我們會加入 \(k\) 個點,由上面的結論,我們可以推出最多要新增 \(klogk\) 個點。
而正如我們所知,每次合併兩棵線段樹同位置的點,就會少掉一個點,複雜度為 \(O(1)\),總共 \(klogk\)個點,全部合併的複雜度就是 \(O(klogk)\)

可見,上面那個證明是隻與插入點個數 \(k\) 有關,也就是插入次數在\(10^5\)左右、值域 \(10^5\)左右的題目,線段樹合併還是比較穩的。

下面我們就來看幾道例題吧QAQ。

P3605 [USACO17JAN]Promotion Counting P

比較板的題了。

對於每個節點都開一個權值線段樹,dfs的時候往上合併一下子樹的資訊。

\(x\) 答案就是 \([a[x]+1,n]\) 的區間和。

注意要離散化一下,陣列儘量開大點。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 1e5+10;
int n,tot,u,cnt;
int head[N],rt[N],a[N],b[N],ans[N];
inline int read()
{
	int s = 0,w = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
	return s * w;
}
struct node
{
	int to,net;
}e[N<<1];
struct Tree
{
	int lc,rc,sum;
}tr[N*20];
void add(int x,int y)
{
	e[++tot].to = y;
	e[tot].net = head[x];
	head[x] = tot;
}
void insert(int &p,int l,int r,int x,int val)//動態開點
{
	if(!p) p = ++cnt;
	tr[p].sum += val;
	int mid = (l + r)>>1;
	if(l == r) return;
	if(x <= mid) insert(tr[p].lc,l,mid,x,val);
	if(x > mid) insert(tr[p].rc,mid+1,r,x,val);
	tr[p].sum = tr[tr[p].lc].sum + tr[tr[p].rc].sum;
}
int query(int o,int l,int r,int L,int R)
{
	int res = 0;
	if(!o) return 0;
	if(L <= l && R >= r) return tr[o].sum;
	int mid = (l + r)>>1;
	if(L <= mid) res += query(tr[o].lc,l,mid,L,R);
	if(R > mid) res += query(tr[o].rc,mid+1,r,L,R);
	return res;
}
void merage(int &x,int y)
{
	if(!x) {x = y; return;}
	if(!y) return;
	tr[x].sum += tr[y].sum;
	merage(tr[x].lc,tr[y].lc);
	merage(tr[x].rc,tr[y].rc);
}
void dfs(int x,int fa)
{
	insert(rt[x],1,n,a[x],1);
	for(int i = head[x]; i; i = e[i].net)
	{
		int to = e[i].to;
		if(to == fa) continue;
		dfs(to,x);
		merage(rt[x],rt[to]);//合併一下
	}
	ans[x] = query(rt[x],1,n,a[x]+1,n);//統計一下答案
}
int main()
{
	n = read();
	for(int i = 1; i <= n; i++) a[i] = b[i] = read();
	sort(b+1,b+n+1);
	int num = unique(b+1,b+n+1)-b-1;
	for(int i = 1; i <= n; i++) a[i] = lower_bound(b+1,b+num+1,a[i])-b;//離散化
	for(int i = 2; i <= n; i++)
	{
		u = read();
		add(u,i); add(i,u);
	}
	dfs(1,1);//dfs統計答案
	for(int i = 1; i <= n; i++) printf("%d\n",ans[i]);
	return 0;
}
P4556 [Vani有約會]雨天的尾巴 /【模板】線段樹合併

真正的模板題來了。

考慮對每種糧食樹上差分一下, \(d[x]+1,d[y]+1,d[lca]-1,d[fa[lca]]-1\),隨便拿線段樹維護一下這些差分陣列。

最後在 \(dfs\) 一下統計每個點的答案,順便把兒子節點的線段樹和父親節點合併一下。

然後這道題就做完了。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int lim = 1e5;
const int N = 3e5+10;
int n,m,tot,cnt,x,y,z,u,v;
int head[N],dep[N],fa[N],siz[N],son[N],top[N],rt[N],ans[N];
struct node
{
	int to,net;
}e[N<<1];
struct Tree
{
	int lc,rc;
	int sum,id;
}tr[N<<4];
inline int read()
{
	int s = 0,w = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
	return s * w;
}
void add(int x,int y)
{
	e[++cnt].to = y;
	e[cnt].net = head[x];
	head[x] = cnt;
}
void get_tree(int x)//樹剖求lca
{
	dep[x] = dep[fa[x]] + 1; siz[x] = 1;
	for(int i = head[x]; i; i = e[i].net)
	{
		int to = e[i].to;
		if(to == fa[x]) continue;
		fa[to] = x;
		get_tree(to);
		siz[x] += siz[to];
		if(siz[to] > siz[son[x]]) son[x] = to;
	}
}
void dfs(int x,int topp)
{
	top[x] = topp;
	if(son[x]) dfs(son[x],topp);
	for(int i = head[x]; i; i = e[i].net)
	{
		int to = e[i].to;
		if(to == fa[x] || to == son[x]) continue;
		dfs(to,to);
	}
}
int lca(int x,int y)
{
	while(top[x] != top[y])
	{
		if(dep[top[x]] < dep[top[y]]) swap(x,y);
		x = fa[top[x]];
	}
	return dep[x] <= dep[y] ? x : y;
}
void up(int p)
{
	tr[p].sum = max(tr[tr[p].lc].sum,tr[tr[p].rc].sum);
	tr[p].id = tr[p].sum == tr[tr[p].lc].sum ? tr[tr[p].lc].id : tr[tr[p].rc].id;
}
void insert(int &p,int l,int r,int x,int val)
{
	if(!p) p = ++tot;
	if(l == r)
	{
		tr[p].sum += val;
		tr[p].id = x;
		return;
	}
	int mid = (l + r)>>1;
	if(x <= mid) insert(tr[p].lc,l,mid,x,val);
	if(x > mid) insert(tr[p].rc,mid+1,r,x,val);
	up(p);
}
pair<int,int> query(int o,int l,int r,int L,int R)
{
	int ans = 0, id = 0;
	if(!o) return make_pair(0,0);
	if(L <= l && R >= r)
	{
		if(tr[o].sum == 0) return make_pair(0,0);
		else return make_pair(tr[o].sum,tr[o].id);
	}
	int mid = (l + r)>>1;
	if(L <= mid) 
	{
		pair<int,int> kk = query(tr[o].lc,l,mid,L,R);
		if(ans < kk.first)
		{
			ans = kk.first;
			id = kk.second;
		}
	}
	if(R > mid)
	{
		pair<int,int> kk = query(tr[o].rc,mid+1,r,L,R);
		if(ans < kk.first)
		{
			ans = kk.first;
			id = kk.second;
		}
	}
	return make_pair(ans,id);
}
void merage(int &x,int y,int l,int r)
{
	if(!x) {x = y; return;}
	if(!y) return;
	int mid = (l + r)>>1;
	if(l == r) 
	{
		tr[x].sum += tr[y].sum;
		return;
	}
	merage(tr[x].lc,tr[y].lc,l,mid);
	merage(tr[x].rc,tr[y].rc,mid+1,r);
	up(x);
}
void get_ans(int x,int fa)
{
	for(int i = head[x]; i; i = e[i].net)
	{
		int to = e[i].to;
		if(to == fa) continue;
		get_ans(to,x);
		merage(rt[x],rt[to],1,lim);//合併兒子節點
	}
	pair<int,int> kk = query(rt[x],1,lim,1,lim);
	ans[x] = kk.second;
}
int main()
{
	n = read(); m = read();
	for(int i = 1; i <= n-1; i++)
	{
		u = read(); v = read();
		add(u,v); add(v,u);
	}
	get_tree(1); dfs(1,1);
	for(int i = 1; i <= m; i++)
	{
		x = read(); y = read(); z = read();
		int Lca = lca(x,y);
		insert(rt[x],1,lim,z,1);//樹上差分
		insert(rt[y],1,lim,z,1);
		insert(rt[Lca],1,lim,z,-1);
		insert(rt[fa[Lca]],1,lim,z,-1);
	}
	get_ans(1,1);//統計答案
	for(int i = 1; i <= n; i++) printf("%d\n",ans[i]);
	return 0;
}

P3224 [HNOI2012]永無鄉

板子題。

對於每個聯通塊都開個權值線段樹。並查集維護這些塊的連通性。

並查集合並的時候,順便把這兩個聯通塊的線段樹合併一下。

注意並查集合並的方向要和線段樹合併的方向相同。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 1e5+10;
int n,m,q,u,v,tot,x,y;
int p[N],fa[N],rt[N];
struct Tree
{
	int lc,rc,sum,id;
}tr[N*20];
inline int read()
{
	int s = 0,w = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
	return s * w;
}
int find(int x)
{
	if(fa[x] == x) return x;
	else return fa[x] = find(fa[x]);
}
void insert(int &p,int l,int r,int x,int val,int id)
{
	if(!p) p = ++tot;	
	tr[p].sum += val;
	if(l == r) {tr[p].id = id; return;}
	int mid = (l + r)>>1;
	if(x <= mid) insert(tr[p].lc,l,mid,x,val,id);
	if(x > mid) insert(tr[p].rc,mid+1,r,x,val,id);
	tr[p].sum = tr[tr[p].lc].sum + tr[tr[p].rc].sum;
}
int query(int o,int l,int r,int k)//線段樹二分求區間第k大
{
	if(l == r) return tr[o].id;
	if(!o) return -1;
	int mid = (l + r)>>1;
	if(tr[tr[o].lc].sum >= k) return query(tr[o].lc,l,mid,k);
	else return query(tr[o].rc,mid+1,r,k-tr[tr[o].lc].sum);
}
void merage(int &x,int y)
{
	if(!x){x = y; return;}
	if(!y) return;
	tr[x].sum += tr[y].sum;
	tr[x].id += tr[y].id;
	merage(tr[x].lc,tr[y].lc);
	merage(tr[x].rc,tr[y].rc);
}
int main()
{
	n = read(); m = read();
	for(int i = 1; i <= n; i++)
	{
		p[i] = read();
		insert(rt[i],1,n+1,p[i],1,i);
	}
	for(int i = 1; i <= n; i++) fa[i] = i;
	for(int i = 1; i <= m; i++)
	{
		u = read(); v = read();
		int fx = find(u);
		int fy = find(v);
		if(fx == fy) continue;
		fa[fy] = fx;//並查集合並的方向要和線段樹合併的方向相同
		merage(rt[fx],rt[fy]);//合併兩個聯通塊的線段樹
	}
	q = read();
	for(int i = 1; i <= q; i++)
	{
		char ch; cin>>ch;
		x = read(); y = read();
		if(ch == 'Q')
		{
			int fx = find(x);
			int ans = query(rt[fx],1,n+1,y);
			printf("%d\n",ans == n+1 ? -1 : ans);
		}
		else
		{
			int fx = find(x);
			int fy = find(y);
			if(fx == fy) continue;
			fa[fy] = fx;
			merage(rt[fx],rt[fy]);//合併
		}
	}
	return 0;
}
P4197 Peaks

這題需要好好的想一想。

好像聽別的巨佬說這似乎是克魯斯卡爾重構樹的板子題,但我太菜了,沒學會。

所以只能拿點簡單的做法水一下了。

考慮沒有邊權的情況就和上面永無鄉那道題是一樣的題。

有了邊權處理起來就比較麻煩。

我們可以考慮對詢問離線一下。

從困難值小的開始處理,每次把邊權比困難值小的邊加進去合併一下。

還是和上面那個題一樣的思路並查集維護聯通塊連通性,線段樹維護每個聯通塊的資訊。

並查集合並的時候順便把線段樹也合併一下,然後這題就做完了。

一個要注意的點就是在同一聯通塊之間的線段樹不要合併,否則複雜度會退化到 \(O(n^2)\) 的。

Code

#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
const int N = 1e5+10;
int n,m,cntq,tot,last = 1;
int h[N],rt[N],fa[N],b[N],ans[500010];
struct bian
{
	int u,v,w;
}e[500010];
struct node
{
	int x,v,k,id;
}q[500010];
struct Tree
{
	int lc,rc,sum;
}tr[N*20];
bool comp(bian a,bian b){ return a.w < b.w;}
bool cmp(node a,node b){ return a.v < b.v; }
inline int read()
{
	int s = 0,w = 1; char ch = getchar();
	while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
	while(ch >= '0' && ch <= '9'){s = s * 10 + ch - '0'; ch = getchar();}
	return s * w;
}
int find(int x)
{
	if(fa[x] == x) return x;
	else return fa[x] = find(fa[x]);
}
void insert(int &p,int l,int r,int x,int val)
{
	if(!p) p = ++tot;
	tr[p].sum += val;
	if(l == r) return;
	int mid = (l + r)>>1;
	if(x <= mid) insert(tr[p].lc,l,mid,x,val);
	if(x > mid) insert(tr[p].rc,mid+1,r,x,val);
	tr[p].sum = tr[tr[p].lc].sum + tr[tr[p].rc].sum;
}
int query(int o,int l,int r,int k)//線段樹二分求區間第k大,優先遞迴右子樹
{
	if(l == r) return l;
	if(tr[o].sum < k) return -1;
	int mid = (l + r)>>1;
	if(tr[tr[o].rc].sum >= k) return query(tr[o].rc,mid+1,r,k);
	else return query(tr[o].lc,l,mid,k-tr[tr[o].rc].sum);
}
void merage(int &x,int y)
{
	if(!x){x = y; return;}
	if(!y) return;
	tr[x].sum += tr[y].sum;
	merage(tr[x].lc,tr[y].lc);
	merage(tr[x].rc,tr[y].rc);
}
int main()
{
	n = read(); m = read(); cntq = read();
	for(int i = 1; i <= n; i++) b[i] = h[i] = read();
	sort(b+1,b+n+1);
	int num = unique(b+1,b+n+1)-b-1;
	for(int i = 1; i <= n; i++) fa[i] = i;
	for(int i = 1; i <= n; i++) h[i] = lower_bound(b+1,b+num+1,h[i])-b;
	for(int i = 1; i <= m; i++)
	{
		e[i].u = read();
		e[i].v = read();
		e[i].w = read();
	}
	sort(e+1,e+m+1,comp);
	for(int i = 1; i <= cntq; i++)
	{
		q[i].x = read();
		q[i].v = read();
		q[i].k = read();
		q[i].id = i;
	}
	sort(q+1,q+cntq+1,cmp);//離線一下
	for(int i = 1; i <= n; i++)
	{
		insert(rt[i],1,n,h[i],1);
	}
	for(int i = 1; i <= cntq; i++)
	{
		while(last <= m && e[last].w <= q[i].v)
		{
			int fx = find(e[last].u);
			int fy = find(e[last].v);
			if(fx == fy){ last++; continue;}
			fa[fy] = fx;
			merage(rt[fx],rt[fy]);//線段樹並查集合並
			last++;
		} 
		int fx = find(q[i].x);
		int res = query(rt[fx],1,n,q[i].k);
		ans[q[i].id] = res == -1 ? -1 : b[res];
	}
	for(int i = 1; i <= cntq; i++) printf("%d\n",ans[i]);
	return 0;
}

課後練習題:

1.CF1009F Dominant Indices 線段樹合併優化dp

2.CF490F Treeland Tour 線段樹合併優化dp