1. 程式人生 > 其它 >「學習筆記」平衡樹Splay

「學習筆記」平衡樹Splay

(1)Rotate(x)

\(x\) 向上移動一級,並將 \(fa[x]\) 作為兒子,保持 \(BST\) 的性質。

inline void Rotate(const int &x){//旋轉 
	int y=fa[x],z=fa[y]; 
	int b=(lc[y]==x)?rc[x]:lc[x];
	fa[x]=z,fa[y]=x;
	if(b)fa[b]=y;
	if(z)(y==lc[z]?lc[z]:rc[z])=x;
	if(x==lc[y])rc[x]=y,lc[y]=b;
	else lc[x]=y,rc[y]=b;
}

(2)Splay(x,target)

將伸展樹中節點 \(x\)

調整為 \(target\) 的兒子。是伸展樹的核心。

  • 如果爺爺是 \(target\),單旋一次
  • 如果爺爺不是 \(target\) 且父親和爺爺方向一致,先旋父親再旋自己
  • 如果爺爺不是 \(target\) 且父親和爺爺方向不同,旋兩次自己
inline bool Wrt(const int &x){//判斷x是否為父親的右兒子 
	return rc[fa[x]]==x;
}
inline void Splay(const int &x,const int &tar){//調整 
	while(fa[x]!=tar){
		if(fa[fa[x]]!=tar)
			Wrt(x)==Wrt(fa[x])?Rotate(fa[x]):Rotate(x);
		Rotate(x);
	}
	if(!tar)rt=x;
}

(3)Find(v)

inline int Find(int v){
	int x=rt;
	while(x){
		if(v==val[x])break;
		if(v> val[x])x=rc[x];
		else x=lc[x];
	}
	if(x)Splay(x,0);
	return x;
}

(4)Insert(v)

inline void Insert(int v){//插入 
	int x=rt,y=0,lr;
	while(x){
		++sz[y=x];
		if(v<val[x])lr=0,x=lc[x];
		else lc=1,x=rc[x];
	}
	x=++tot;
	fa[x]=y,val[x]=v,sz[x]=1;
	if(y)(lr==0?lc[y]:rc[y])=x;
	Splay(x,0);
}

(5)Join(u,v)

inline void Join(int u,int v){//合併 
	fa[u]=fa[v]=0;
	int w=u;
	while(rc[w])w=rc[w];
	Splay(w,0);
	rc[w]=v,fa[v]=w;
}

(6)Delete(u)

inline void Delete(int u){//刪除 
	Splay(u,0);
	if(!lc[u] || !rc[u])
		fa[rt=lc[u]+rc[u]]=0;
	else Join(lc[u],rc[u]);
	lc[u]=rc[u]=0;
}

(7)GetRank(v)

排名可表示為左子樹大小+1

inline int GetRank(int v){//查詢排名 
	int x=Find(v);
	return sz[lc[x]]+1;
}

(8)Spilt(v)

以元素 \(v\) 所在的節點 \(x\) 為界,將伸展樹分為左右兩顆

首先執行 Find(v) ,再 Splay(x,0) ,最後 \(x\) 的左子樹為 \(S1\) ,右子樹為 \(S2\)

其他

除了上述8種基本操作,伸展樹還支援求前驅、後繼、最值、排名第k的元素等操作

upd:2021.4.27 P3369 【模板】普通平衡樹

#include <bits/stdc++.h>
using namespace std;
inline int gin(){
	int s=0,f=1;
	char c=getchar();
	while(c<'0' || c>'9'){
		if(c=='-') f=-1;
		c=getchar();
	}
	while(c>='0'&&c<='9'){
		s=(s<<3)+(s<<1)+(c^48);
		c=getchar();
	}
	return s*f;
}

const int N=1e6+5;
int n,fa[N],lc[N],rc[N],val[N],sz[N],rt,tot,cnt[N];

inline void push(int x){
	if(x){
		sz[x]=cnt[x];
		if(lc[x]) sz[x]+=sz[lc[x]];
		if(rc[x]) sz[x]+=sz[rc[x]];
	}
	return;
}
inline bool get(int x){
	return x==rc[fa[x]];
}
inline void clear(int x){
	lc[x]=rc[x]=fa[x]=val[x]=sz[x]=cnt[x]=0;
}

void rotate(int x){
	int y=fa[x],z=fa[y];
	int b=lc[y]==x ? rc[x] : lc[x];
	fa[x]=z,fa[y]=x;
	if(b) fa[b]=y;
	if(z) (lc[z]==y ? lc[z] : rc[z])=x;
	if(x==lc[y]) rc[x]=y,lc[y]=b;
	else lc[x]=y,rc[y]=b;
	push(x);
	push(y);
}

void splay(int x,int tar=0){
	while(fa[x]!=tar){
		if(fa[fa[x]]!=tar)
			get(x)==get(fa[x]) ? rotate(fa[x]) : rotate(x);
		rotate(x);
	}
	if(!tar) rt=x;
}

inline int find(int v){
	int x=rt;
	while(x){
		if(val[x]==v) break;
		if(val[x]> v) x=lc[x];
		else x=rc[x];
	}
	if(x) splay(x);
	return x;
}

inline void ins(int v){
	int w=find(v);
	if(w){
		cnt[w]++;
		return;
	}
	int x=rt,y=0,dir;
	while(x){
		++sz[y=x];
		if(v<val[x]) x=lc[x],dir=0;
		else x=rc[x],dir=1;
	}
	x=++tot;
	fa[x]=y,val[x]=v,sz[x]=1,cnt[x]=1;
	if(y) (dir==0 ? lc[y] : rc[y])=x;
	push(y);
	splay(x);
}

inline int rk(int v){
	int x=find(v);
	return sz[lc[x]]+1;
}

inline int kth(int k){
	int x=rt,y=k;
	while(x){
		if(y<=sz[lc[x]])
			x=lc[x];
		else {
			y-=sz[lc[x]]+cnt[x];
			if(y<=0){
				splay(x);
				return val[x];
			}
			x=rc[x];
		}
	}
}

inline void join(int u,int v){
	fa[u]=fa[v]=0;
	int w=u;
	while(rc[w]) w=rc[w];
	splay(w);
	rc[w]=v,fa[v]=w;
}

inline void del(int v){
	int x=find(v);
	if(cnt[x]>=2){
		cnt[x]--;
		push(x);
		return;
	}
	if(!lc[x] || !rc[x])
		fa[rt=lc[x]|rc[x]]=0;
	else join(lc[x],rc[x]);
}

inline int pre(){
	int x=lc[rt];
	while(rc[x]) x=rc[x];
	splay(x);
	return val[x];
}

inline int nxt(){
	int x=rc[rt];
	while(lc[x]) x=lc[x];
	splay(x);
	return val[x];
}

int main(){
	n=gin();
	for(int i=1;i<=n;i++){
		int op=gin(),x=gin();
		switch(op){
			case 1:ins(x);break;
			case 2:del(x);break;
			case 3:printf("%d\n",rk(x));break;
			case 4:printf("%d\n",kth(x));break;
			case 5:ins(x);printf("%d\n",pre());del(x);break;
			case 6:ins(x);printf("%d\n",nxt());del(x);break;
		}
	}
	return 0;
}