洛谷 P3224 [HNOI2012]永無鄉(Splay合併)
阿新 • • 發佈:2021-11-18
傳送門
解題思路
若干平衡樹,每次操作有兩種,一是合併兩個Splay,二是查詢某一個點所在的平衡樹裡的第k小的點的編號。
首先用並查集維護某個點在哪個平衡樹裡,然後rt[i]記錄編號為i的平衡樹的根。
每次合併時啟發式合併,直接把小的樹的每個點暴力insert到大樹裡。
查詢正常操作即可。
為了方便可以把原來普通平衡樹的0節點,改為第i課樹的0節點是i,然後其他點的編號都加n,方便操作。
AC程式碼
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<algorithm> #include<vector> #include<queue> #include<map> #include<bitset> #include<stack> using namespace std; const int maxn=2e5+5; int fa[maxn],siz[maxn],n,m,rt[maxn],q; struct node{ int fa,son[2],val,siz; }tr[maxn]; inline int find(int x){ if(fa[x]==x) return x; return fa[x]=find(fa[x]); } void init(int i,int fa){ tr[i].fa=fa; tr[i].son[0]=tr[i].son[1]=0; tr[i].siz=1; } void update(int x){ tr[x].siz=1; if(tr[x].son[0]) tr[x].siz+=tr[tr[x].son[0]].siz; if(tr[x].son[1]) tr[x].siz+=tr[tr[x].son[1]].siz; } void rotate(int x){ int y=tr[x].fa,z=tr[y].fa; int c=(tr[y].son[1]==x); tr[tr[x].son[!c]].fa=y; tr[x].fa=tr[y].fa; tr[y].fa=x; if(z) tr[z].son[tr[z].son[1]==y]=x; tr[y].son[c]=tr[x].son[!c]; tr[x].son[!c]=y; update(y); update(x); } void splay(int x,int goal){ if(x==goal) return; while(tr[x].fa!=goal){ int y=tr[x].fa,z=tr[y].fa; if(z!=goal) ((tr[y].son[0]==x)^(tr[z].son[0]==y))?rotate(x):rotate(y); rotate(x); } if(goal<=n) rt[goal]=x; } void insert(int y,int id){ int x=rt[y]; while(1){ if(tr[x].son[tr[x].val<tr[id].val]) x=tr[x].son[tr[x].val<tr[id].val]; else{ init(id,x); tr[x].son[tr[x].val<tr[id].val]=id; splay(id,y); return; } } } void del(int x,int y){ if(tr[x].son[0]) del(tr[x].son[0],y); if(tr[x].son[1]) del(tr[x].son[1],y); insert(y,x); } void merge(int x,int y){ int fx=find(x),fy=find(y); if(fx==fy) return; if(siz[fx]>siz[fy]) swap(fx,fy); fa[fx]=fy; siz[fy]+=siz[fx]; del(rt[fx],fy); } int getval(int x,int k){ if(tr[x].siz<k) return -1; while(1){ if((tr[x].son[0]?tr[tr[x].son[0]].siz+1:1)==k){ return x-n; } if((tr[x].son[0]?tr[tr[x].son[0]].siz+1:1)<k){ k-=(tr[x].son[0]?tr[tr[x].son[0]].siz+1:1); x=tr[x].son[1]; }else{ x=tr[x].son[0]; } } } int main(){ ios::sync_with_stdio(false); cin>>n>>m; for(int i=1;i<=n;i++) cin>>tr[i+n].val,tr[i+n].fa=i,fa[i]=fa[i+n]=i,siz[i]=1,rt[i]=i+n,tr[i+n].siz=1; for(int i=1;i<=m;i++){ int u,v; cin>>u>>v; merge(u,v); } cin>>q; for(int i=1;i<=q;i++){ char c; int x,y; cin>>c>>x>>y; if(c=='Q'){ cout<<getval(rt[find(x)],y)<<endl; }else{ merge(find(x),find(y)); } } return 0; }