1. 程式人生 > >二逼平衡樹(樹套樹)

二逼平衡樹(樹套樹)

傳送門

這道題的做法……我學的是最經典的線段樹套平衡樹。

因為發現其實這題的題目描述和普通平衡樹非常的相似……只是這次是在給定的區間中。所以我們能想象到用線段樹維護區間,然後每個線段樹的節點都是一顆平衡樹,用於維護區間內資訊。

具體操作的實現辦法:
1.查詢k在區間內的排名:在給定的區間的每一個平衡樹上求k的排名,其和即為答案。
2.查詢區間內排名為k的數:這個操作是不能線上段樹上疊加的,所以我們需要二分答案轉化為判定類問題,就是轉化為問題一。
3.修改某一位置上數值:在給定區間內的所有平衡樹上找到這個數並且修改。
4.查詢k在區間內前驅:在給定區間所有平衡樹內查k的前驅,取最大值。
5.查詢k在區間內後繼:在給定區間所有平衡樹內查k後繼,取最小值。

以上操作除了操作2需要二分答案,複雜度是\(O(log^3n)\),剩下的都是\(O(log^2n)\)的。

然後這個具體的實現方法很複雜……其實平衡樹內部和線段樹內部的操作和普通的方法基本都是大同小異的。不同的在於插入,刪除節點以及新建節點。具體的思路其實和平衡樹也很像……不過轉移到線段樹上比較複雜,但是看看程式碼都能看懂。

注意這題要垃圾回收。還有就是樹套樹真的好長……還挺容易寫錯的……這玩意也是咋寫都要5k+……

#include<iostream>
#include<cstdio>
#include<cmath>
#include<algorithm>
#include<queue>
#include<cstring>
#define rep(i,a,n) for(register int i = a;i <= n;i++)
#define per(i,n,a) for(register int i = n;i >= a;i--)
#define enter putchar('\n')
#define pr pair<int,int>
#define mp make_pair
#define fi first
#define sc second
#define I inline
#define get(x) (t[t[x].fa].ch[1] == (x))
using namespace std;
typedef long long ll;
const int M = 1000005;
const int N = 10000005;
const int INF = 2147483647;

int read()
{
   int ans = 0,op = 1;char ch = getchar();
   while(ch < '0' || ch > '9') {if(ch == '-') op = -1;ch = getchar();}
   while(ch >='0' && ch <= '9') ans = ans * 10 + ch - '0',ch = getchar();
   return ans * op;
}

struct tree
{
   int ch[2],val,size,cnt,fa;
}t[M];

int n,m,a[M],root[M],idx,bin[M],btop,op,l,r,x,y,z;

I int newnode(int x)
{
   int u = btop ? bin[btop--] : ++idx;
   t[u].val = x,t[u].cnt = t[u].size = 1,t[u].fa = t[u].ch[0] = t[u].ch[1] = 0;
   return u;
}

I void update(int u)
{
   t[u].size = t[t[u].ch[0]].size + t[t[u].ch[1]].size + t[u].cnt;
}

I void rotate(int x)
{
   int y = t[x].fa,z = t[y].fa,k = get(x);
   if(z) t[z].ch[get(y)] = x;
   t[x].fa = z,t[y].ch[k] = t[x].ch[k^1],t[t[y].ch[k]].fa = y;
   t[x].ch[k^1] = y,t[y].fa = x;
   update(y),update(x);
}

I void splay(int x)
{
   while(t[x].fa)
   {
      int y = t[x].fa,z = t[y].fa;
      if(z) ((t[y].ch[0] == x) ^ (t[z].ch[0] == y)) ? rotate(x) : rotate(y);
      rotate(x);
   }
   update(x);
}

I int getnum(int k,int x)
{
   int u = root[k],v = 0;
   while(u)
   {
      v = u;
      if(t[u].val == x) return u;
      if(x < t[u].val) u = t[u].ch[0];
      else u = t[u].ch[1];
   }
   return v;
}

I int getkth(int k,int x)
{
   int u = root[k];
   while(u)
   {
      if(x <= t[t[u].ch[0]].size) u = t[u].ch[0];
      else if(x > t[t[u].ch[0]].size + t[u].cnt) x -= (t[t[u].ch[0]].size + t[u].cnt),u = t[u].ch[1];
      else return u;
   }
   return 0;
}

I int getless(int k,int x)
{
   int u = root[k],cur = 0;
   while(u)
   {
      if(t[u].val < x) cur += t[u].cnt + t[t[u].ch[0]].size,u = t[u].ch[1];
      else u = t[u].ch[0];
   }
   return cur;
}

I int getmax(int u)
{
   while(t[u].ch[1]) u = t[u].ch[1];
   return u;           
}

I int getmin(int u)
{
   while(t[u].ch[0]) u = t[u].ch[0];
   return u;
}

I int getpre(int k,int x)
{
   int u = getnum(k,x);
   if(!u) return -INF;
   splay(u),root[k] = u;
   if(t[u].val >= x) u = getmax(t[u].ch[0]);
   return u ? t[u].val : -INF;
}

I int getnext(int k,int x)
{
   int u = getnum(k,x);
   if(!u) return INF;
   splay(u),root[k] = u;
   if(t[u].val <= x) u = getmin(t[u].ch[1]);
   return u ? t[u].val : INF;
}

I void insert(int k,int x)
{
   int u = getnum(k,x);
   if(t[u].val == x)
   {
      splay(u),root[k] = u;
      t[u].cnt++,t[u].size++;
      return;
   }
   u = newnode(x);
   if(!root[k]) {root[k] = u;return;}
   int v = root[k],w = 0,dir = 0;
   while(v)
   {
      w = v;
      if(t[u].val <= t[v].val) dir = 1,v = t[v].ch[0];
      else dir = 0,v = t[v].ch[1];
   }
   if(dir) t[w].ch[0] = u;
   else t[w].ch[1] = u;
   t[u].fa = w,splay(u),root[k] = u;
}

I void del(int k,int x)
{
   int u = getnum(k,x);
   splay(u),root[k] = u;
   if(t[u].cnt > 1) {t[u].cnt--,t[u].size--;return;}
   if(t[u].size == 1) root[k] = 0;
   else if(!t[u].ch[0] || !t[u].ch[1])
   {
      root[k] = t[u].ch[0] | t[u].ch[1];
      t[root[k]].fa = 0;
   }
   else
   {
      t[t[u].ch[0]].fa = 0;
      int v = getmax(t[u].ch[0]);
      splay(v),root[k] = v;
      t[v].ch[1] = t[u].ch[1],t[t[u].ch[1]].fa = v,update(v);
   }
   bin[++btop] = u;
}

void segbuild(int p,int l,int r)
{
   rep(i,l,r) insert(p,a[i]);
   if(l == r) return;
   int mid = (l+r) >> 1;
   segbuild(p<<1,l,mid),segbuild(p<<1|1,mid+1,r);
}

void segchange(int p,int l,int r,int k,int val)
{
   del(p,a[k]),insert(p,val);
   if(l == r) return;
   int mid = (l+r) >> 1;
   if(k <= mid) segchange(p<<1,l,mid,k,val);
   else segchange(p<<1|1,mid+1,r,k,val);
}

int segless(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getless(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segless(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segless(p<<1|1,mid+1,r,kl,kr,val);
   else return segless(p<<1,l,mid,kl,mid,val) + segless(p<<1|1,mid+1,r,mid+1,kr,val);
}

int segpre(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getpre(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segpre(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segpre(p<<1|1,mid+1,r,kl,kr,val);
   else return max(segpre(p<<1,l,mid,kl,mid,val),segpre(p<<1|1,mid+1,r,mid+1,kr,val));
}

int segnext(int p,int l,int r,int kl,int kr,int val)
{
   if(l == kl && r == kr) return getnext(p,val);
   int mid = (l+r) >> 1;
   if(kr <= mid) return segnext(p<<1,l,mid,kl,kr,val);
   else if(kl > mid) return segnext(p<<1|1,mid+1,r,kl,kr,val);
   else return min(segnext(p<<1,l,mid,kl,mid,val),segnext(p<<1|1,mid+1,r,mid+1,kr,val));
}

int segkth(int kl,int kr,int k)
{
   int L = 0,R = 100000000;
   while(L < R)
   {
      int mid = (L+R+1) >> 1;
      if(segless(1,1,n,kl,kr,mid) > k-1) R = mid - 1;
      else L = mid;
   }
   return L;
}

int main()
{
   n = read(),m = read(),t[0].val = -1;
   rep(i,1,n) a[i] = read();
   segbuild(1,1,n);
   while(m--)
   {
      op = read();
      if(op == 3) x = read(),y = read(),segchange(1,1,n,x,y),a[x] = y;
      else l = read(),r = read(),x = read();
      if(op == 1) printf("%d\n",segless(1,1,n,l,r,x) + 1);
      if(op == 2) printf("%d\n",segkth(l,r,x));
      if(op == 4) printf("%d\n",segpre(1,1,n,l,r,x));
      if(op == 5) printf("%d\n",segnext(1,1,n,l,r,x));
   }
   return 0;
}