1. 程式人生 > 實用技巧 >演算法初探 - 平衡樹

演算法初探 - 平衡樹

更新記錄

【1】2020.08.14-17:01

  • 1.完善Splay內容

正文

Splay

Splay是一種沒有用隨機數函式的平衡樹,它依靠伸展操作來維持平衡

不過也正是這樣,導致其能維護數列的某些特殊的區間操作,例如區間反轉

直接來說Splay的核心:rotate與splay操作

rotate就是旋轉,又稱之為上旋

就是將一個非根結點向上旋轉,旋轉前後都滿足二叉搜尋樹的性質

那麼易得一個左兒子n旋轉到father的位置:

  • 其左兒子比n小,不動
  • 其右兒子比n大,但是比father小,連線到father的左兒子上

看個例子,考慮一棵完全二叉搜尋樹:

將其中結點2旋轉到根上

按照之前的結論,將2轉到根上,然後將3連線到5上

最後就是這樣:

很簡單,對吧?

所以我們很容易的就寫出了rotate的程式碼

inline void rotate(int n){
	int fa=t[n].fa,grfa=t[t[n].fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa);update(n);
}

在這裡可能你不知道某些函式是幹嘛的
沒關係,這些函式因為太簡單了被我放在了後面統一說明

splay就是多個rotate集合,用來將一個結點旋轉到根上

考慮結點n,其父親fa,父親的父親grfa

  • 如果三點一線:先旋轉fa,再旋轉n
  • 其他情況:旋轉兩次n

對於三點一線的情況,如果旋轉兩次n,極易出現鏈的情況且易被卡,考慮一條鏈即可

所以我們依然能夠非常容易的寫出程式碼

inline void splay(int n){
	while(t[n].fa){
		if(t[t[n].fa].fa)
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		rotate(n);
	}
	root=n;
}

好了這就是splay的核心操作了,是不是非常簡單呢?

接下來是其他非主要函式的講解

connect連邊函式,將兩個點連在一起

inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}

update更新以此結點為根的子樹的大小

inline void update(int n){
	if(n){
		t[n].size=t[n].num;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}

confirm函式用來確認這個結點是父結點的左兒子還是右兒子

inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}

clear用來清除一個結點的全部資訊

inline void clear(int n){
	t[n].fa=t[n].num=t[n].size=t[n].v=t[n].son[0]=t[n].son[1]=0;
}

insert插入結點

  • 當樹中沒有結點時:新建結點
  • 當樹中有結點的時候
    • 當要插入的結點已經存在,直接num+1即可
    • 不存在時,新建結點

具體細節:

inline void insert(int n){
	if(!root){
		root=++node;t[root].v=n;
		t[root].num=t[root].size=1;
		return;
	}
	int nd=root,fa=0;
	while(1){
		if(t[nd].v==n){
			t[nd].num+=1;
			update(nd);update(fa);
//更新
			splay(nd);
//旋轉
			return;
		}
		fa=nd,nd=t[nd].son[n>t[nd].v];
		if(!nd){
			node+=1;t[node].v=n;
			t[node].num=t[node].size=1;
			t[node].fa=fa;
			t[fa].son[n>t[fa].v]=node;
//n>t[fa].v是根據二叉搜尋樹的性質來確定這個結點在哪裡
			update(fa);splay(node);
			return;
		}
	}
}

numrank用來檢視一個數的排名

依然考慮二叉搜尋樹的性質:

  • 如果這個數比這個結點小,那麼搜尋左子樹
  • 如果這個數和這個結點一樣大,那麼它肯定比左子樹的所有節點都大,sum加上左子樹的size,之後+1就是排名
  • 如果這個數比這個結點大,那麼sum加上左子樹的size,加上這個結點的num,之後搜尋右子樹
inline int numrank(int n){
	int nd=root,sum=0;
	while(1){
		if(n<t[nd].v){
			nd=t[nd].son[0];
			continue;
		}
		sum+=t[t[nd].son[0]].size;
		if(n==t[nd].v){
			splay(nd);
			return sum+1;
		}
		sum+=t[nd].num;
		nd=t[nd].son[1];
	}
}

ranknum用來查詢處於這個排名的數

  • 如果現在排名等於小於左子樹的結點的個數,那麼搜尋左子樹
  • 否則,如果現在排名小於左子樹的結點的個數+這個結點的num,那麼就是這個結點
  • 否則,排名減去左子樹的size+結點的num,搜尋右子樹
inline int ranknum(int n){
	int nd=root;
	while(1){
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+t[nd].num){
			splay(nd);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+t[nd].num;
			nd=t[nd].son[1];
		}
	}
}

presuf用來求前趨與後繼,本來是兩個函式被我合併成了一個

以查前趨為例,插入要查詢的n
此時n是根結點,前趨肯定是比n小
於是乎我們從左子樹瘋狂向右找
最後找到,之後刪除n

inline int presuf(int n,bool y,bool r){
	if(y) insert(n);
	int nd=t[root].son[r];
	while(t[nd].son[r^1]) nd=t[nd].son[r^1];
	if(y) del(n);
	return nd;
}

del就是刪除結點了

先把n轉上來

  • 如果num>1那麼直接-1就行了
  • 否則,如果沒兒子,那說明就這一個結點,清空
  • 然後哪邊缺root是哪邊
  • 之後如果都有就轉一下,連個邊
inline void del(int n){
	numrank(n);
	int cp=root;bool more=1;
	if(t[root].num>1){
		more=0;t[root].num-=1;
		update(root);
	}
	else if(!t[root].son[0]&&!t[root].son[1]) root=0;
	else if(!t[root].son[0]){
		root=t[root].son[1];
		t[root].fa=0;
	}
	else if(!t[root].son[1]){
		root=t[root].son[0];
		t[root].fa=0;
	}
	else{
		splay(presuf(root,0,0));
		connect(root,t[cp].son[1],1);
	}
	if(more) clear(cp);
	update(root);
}

完整Splay程式碼:

#include<cstdio>
#include<iostream>
#define N 1000100
struct baltree{
	int fa,v,size,num,son[2];
}t[N];
int root,node,n,a,b;
inline void del(int n);
inline int read(){
	int sum=0,chs=1;char c=getchar();
	while(!isdigit(c)){
		if(c=='-') chs=-1;
		c=getchar();
	}
	while(isdigit(c)){
		sum=(sum<<1)+(sum<<3)+c-48;
		c=getchar();
	}
	return sum*chs;
}
inline void clear(int n){
	t[n].fa=t[n].num=t[n].size=t[n].v=t[n].son[0]=t[n].son[1]=0;
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void update(int n){
	if(n){
		t[n].size=t[n].num;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[t[n].fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa);update(n);
}
inline void splay(int n){
	while(t[n].fa){
		if(t[t[n].fa].fa)
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		rotate(n);
	}
	root=n;
}
inline void insert(int n){
	if(!root){
		root=++node;t[root].v=n;
		t[root].num=t[root].size=1;
		return;
	}
	int nd=root,fa=0;
	while(1){
		if(t[nd].v==n){
			t[nd].num+=1;
			update(nd);update(fa);
			splay(nd);
			return;
		}
		fa=nd,nd=t[nd].son[n>t[nd].v];
		if(!nd){
			node+=1;t[node].v=n;
			t[node].num=t[node].size=1;
			t[node].fa=fa;
			t[fa].son[n>t[fa].v]=node;
			update(fa);splay(node);
			return;
		}
	}
}
inline int numrank(int n){
	int nd=root,sum=0;
	while(1){
		if(n<t[nd].v){
			nd=t[nd].son[0];
			continue;
		}
		sum+=t[t[nd].son[0]].size;
		if(n==t[nd].v){
			splay(nd);
			return sum+1;
		}
		sum+=t[nd].num;
		nd=t[nd].son[1];
	}
}
inline int ranknum(int n){
	int nd=root;
	while(1){
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+t[nd].num){
			splay(nd);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+t[nd].num;
			nd=t[nd].son[1];
		}
	}
}
inline int presuf(int n,bool y,bool r){
	if(y) insert(n);
	int nd=t[root].son[r];
	while(t[nd].son[r^1]) nd=t[nd].son[r^1];
	if(y) del(n);
	return nd;
}
inline void del(int n){
	numrank(n);
	int cp=root;bool more=1;
	if(t[root].num>1){
		more=0;t[root].num-=1;
		update(root);
	}
	else if(!t[root].son[0]&&!t[root].son[1]) root=0;
	else if(!t[root].son[0]){
		root=t[root].son[1];
		t[root].fa=0;
	}
	else if(!t[root].son[1]){
		root=t[root].son[0];
		t[root].fa=0;
	}
	else{
		splay(presuf(root,0,0));
		connect(root,t[cp].son[1],1);
	}
	if(more) clear(cp);
	update(root);
}
signed main(){
	n=read();
	while(n--){
		a=read(),b=read();
		if(a==1) insert(b);
		else if(a==2) del(b);
		else if(a==3) printf("%d\n",numrank(b));
		else if(a==4) printf("%d\n",ranknum(b));
		else if(a==5) printf("%d\n",t[presuf(b,1,0)].v);
		else printf("%d\n",t[presuf(b,1,1)].v);
	}
}

區間操作

例如反轉區間 \([l,r]\)

那麼將l的前趨轉到根結點,r的後繼轉到根結點右方

此時r的後繼的左方的子樹就是區間 \([l,r]\)

打上標記,之後看見標記反轉就可以啦

為什麼是這樣?

先來問兩個問題:

  1. Splay旋轉的特點?
  2. 歸併排序的思想?

那麼聰明的同學直接就能想到答案啦

這個子樹在根結點和其父結點不被操作的時候是不會改變的

操作的時候呢?
標記就下傳啦!!

歸併排序的思想就是先將大區間整體反轉,然後小區間反轉......

此時一定拋棄二叉搜尋樹的思想,此時的Splay維護的是區間!

#include<iostream>
#include<cstdio>
#define N 1000100
const int INF=2147483647;
using namespace std;
struct baltree{
	int son[2],v,fa,size,ret;
}t[N];
int root,node=0,n,m,a,b;
inline void pd(int n){
	if(t[n].ret){
		t[n].ret=0;
		t[t[n].son[0]].ret^=1;t[t[n].son[1]].ret^=1;
		swap(t[n].son[0],t[n].son[1]);
	}
}
inline void update(int n){
	if(n){
		t[n].size=1;
		if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
		if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	}
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[fa].fa;
	bool np=confirm(n),fap=confirm(t[n].fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	update(fa),update(n);
}
inline void splay(int n,int p){
	while(t[n].fa!=p){
		if(t[t[n].fa].fa!=p){
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		}
		rotate(n);
	}
	if(!p) root=n;
}
inline void insert(int v){
	if(!root){
		root=++node;
		t[root].v=v;
		t[root].size=1;
		return;
	}
	int n=root,fa=0;
	while(1){
		fa=n,n=t[n].son[v>t[n].v];
		if(!n){
			t[++node].v=v;
			t[node].size=1;
			t[node].fa=fa;
			t[fa].son[v>t[fa].v]=node;
			update(fa);splay(node,0);
			return;
		}
	}
}
inline int ranknum(int n){
	int nd=root;
	while(1){
		pd(nd);
		if(n<=t[t[nd].son[0]].size){
			nd=t[nd].son[0];
			continue;
		}
		else if(n<=t[t[nd].son[0]].size+1){
			splay(nd,0);
			return t[nd].v;
		}
		else{
			n-=t[t[nd].son[0]].size+1;
			nd=t[nd].son[1];
		}
	}
}
inline void reverse(int l,int r,int lth,int rth){
	splay(lth,0);splay(rth,root);
	t[t[rth].son[0]].ret^=1;
}
inline void outdata(int n){
	pd(n);
	if(t[n].son[0]) outdata(t[n].son[0]);
	if(t[n].v>1&&t[n].v<=(::n+1)) printf("%d ",t[n].v-1);
	if(t[n].son[1]) outdata(t[n].son[1]);
}
signed main(){
	cin>>n>>m;
	for(int i=1;i<=n+2;i++)
		insert(i);
	for(int i=0;i<m;i++){
		cin>>a>>b;
		reverse(a,b,ranknum(a),ranknum(b+2));
	}
	outdata(root);
}

看完了板子,來看例題吧

序列終結者
這就是板子的變形啊

#pragma gcc optimize(2)
#pragma gcc optimize(3)
#pragma gcc optimize(-Ofast)
#include<iostream>
#include<cstdio>
#define N 1001000
#define INF 0x3f3f3f3f
int n,m,node,rt,a,b,c,d,rth;
struct baltree{
	int size,son[2],fa,v,ret,add,maxn;
}t[N];
inline int max(int a,int b){return a>b?a:b;}
inline void pu(int n){
	t[n].size=1;
	if(t[n].son[0]) t[n].size+=t[t[n].son[0]].size;
	if(t[n].son[1]) t[n].size+=t[t[n].son[1]].size;
	t[n].maxn=t[n].v;
	if(t[n].son[0]) t[n].maxn=max(t[n].maxn,t[t[n].son[0]].maxn);
	if(t[n].son[1]) t[n].maxn=max(t[n].maxn,t[t[n].son[1]].maxn);
}
inline void pd(int n){
	if(t[n].add){
		t[t[n].son[0]].add+=t[n].add,
		t[t[n].son[0]].v+=t[n].add;
		t[t[n].son[1]].add+=t[n].add,
		t[t[n].son[1]].v+=t[n].add;
		t[t[n].son[0]].maxn+=t[n].add,
		t[t[n].son[1]].maxn+=t[n].add;		
	}
	if(t[n].ret){
		t[t[n].son[0]].ret^=1;
		t[t[n].son[1]].ret^=1;
		std::swap(t[n].son[0],t[n].son[1]);
	}
	t[n].ret=t[n].add=0;
}
inline bool confirm(int n){
	return t[t[n].fa].son[1]==n;
}
inline void connect(int up,int down,bool r){
	if(up) t[up].son[r]=down;
	t[down].fa=up;
}
inline void rotate(int n){
	int fa=t[n].fa,grfa=t[fa].fa;
	bool np=confirm(n),fap=confirm(fa);
	connect(fa,t[n].son[np^1],np);
	connect(n,fa,np^1);
	connect(grfa,n,fap);
	pd(n),pd(fa);
	pu(fa),pu(n);
}
inline void splay(int n,int p){
	while(t[n].fa!=p){
		if(t[t[n].fa].fa!=p){
			if(confirm(t[n].fa)==confirm(n)) rotate(t[n].fa);
			else rotate(n);
		}
		rotate(n);
	}
	if(!p) rt=n;
}
inline void insert(int v){
	if(!rt){
		rt=++node;
		t[rt].v=v;
		t[rt].maxn=v;
		t[rt].size=1;
		return;
	}
	int n=rt,fa=0;
	while(1){
		fa=n,n=t[n].son[v>t[n].v];
		if(!n){
			t[++node].v=v;
			t[node].size=1;
			t[node].fa=fa;
			t[node].maxn=v;
			t[fa].son[v>t[n].v]=node;
			pu(fa);splay(node,0);
			return;
		}
	}
}
inline int ranknode(int rank){
	int n=rt;
	while(1){
		pd(n);
		if(rank<=t[t[n].son[0]].size){
			n=t[n].son[0];
			continue;
		}
		else if(rank<=t[t[n].son[0]].size+1){
			splay(n,0);return n;
		}
		else{
			rank-=t[t[n].son[0]].size+1;
			n=t[n].son[1];
		}
	}
}
inline void reverse(int l,int r){
	t[t[rth].son[0]].ret^=1;
}
inline void add(int l,int r,int v){
	t[t[rth].son[0]].add+=v;
	t[t[rth].son[0]].v+=v;
	t[t[rth].son[0]].maxn+=v;
}
inline void querymax(int l,int r){
	printf("%d\n",t[t[rth].son[0]].maxn);
}
int main(){
	scanf("%d%d",&n,&m);
	for(register int i=0;i<n;++i)
		insert(0);
	insert(INF),insert(-INF);
	for(register int i=0;i<m;++i){
		scanf("%d%d%d",&a,&b,&c);
		rth=ranknode(c+2);
		splay(ranknode(b),0);splay(rth,rt);
		if(a==1){
			scanf("%d",&a);
			add(b,c,a);
		}
		else if(a==2) reverse(b,c);
		else if(a==3) querymax(b,c);
	}
}