程式人生 > (splay)區間第k大

(splay)區間第k大

https://www.luogu.org/problemnew/show/P1801

#include<bits/stdc++.h>
using namespace std;
// Sentinel returned by Splay::find_rank for out-of-range ranks:
// -inf for ranks <= 0 and +inf for ranks beyond the stored count.
// Value is 2,000,000,050, which still fits in a 32-bit int.
constexpr int inf = 2000000000 + 50;

// One splay-tree node. Equal keys share a single node: `recy` counts how
// many copies of `val` it represents, and `size` counts all copies in the
// subtree rooted here (maintained by push_up).
class node
{
public:
    node* ch[2] = {nullptr, nullptr}; // ch[0] = left child, ch[1] = right child
    node* fa = nullptr;               // parent; nullptr for the root
    int val;                          // key stored in this node
    int size = 1;                     // copies in this subtree, including recy
    int recy = 1;                     // multiplicity of val
    // explicit: an int must not silently convert into a node.
    explicit node(int x) : val(x) {}
};
// Recompute v->size as v's own multiplicity plus the sizes of whichever
// children exist (a missing child contributes nothing).
inline void push_up(node *v)
{
    int total = v->recy;
    for (node* child : v->ch)
    {
        if (child)
            total += child->size;
    }
    v->size = total;
}
// Install s as child number x of p (0 = left, 1 = right). When s is
// non-null its parent link is pointed back at p. Whatever p previously
// held in that slot is overwritten; sizes are not updated here.
inline void attach(node *p, node *s, int x)
{
    if (s != nullptr)
        s->fa = p;
    p->ch[x] = s;
}

class Splay//儲存規則:小左大右,重複節點記錄
{
    node* root;
    void rotate(node *v)
    {
        if(v == root) return ;
        node* p = v->fa;
        int flag = p->ch[1] == v;
        if(p->fa)   attach(p->fa, v, p->fa->ch[1] == p);
        else   v->fa = NULL, root = v;
        attach(p, v->ch[flag ^ 1], flag);
        attach(v, p, flag ^ 1);
        push_up(p);
        push_up(v);
    }
    public:
    void init() { root = NULL; }
    node* GetRoot() { return root; }
    node* splay(node *v)
    {
        for(node *p; p = v->fa; rotate(v))
        {
            if(p->fa) rotate((p->fa->ch[0] == p) == (p->ch[0] == v) ? p : v);
        }
        return root = v;
    }
    node* search(int x)//查詢值為x的節點 沒找到返回值最接近的節點
    {
        node* p = root;
        while(p)
        {
            if(p->val == x) break;
            if(p->ch[p->val < x]) p = p->ch[p->val < x];
            else break;
        }
        splay(p);
        return p;
    }
    node* insert(int x)//插入一個值為x的節點
    {
        if(!root) return root = new node(x);
        node* p = search(x);
        if(p->val == x)
        {
            p->recy++, p->size++;
            return p;
        }
        node* v = new node(x);
        int flag = p->val > x;
        attach(v, p, flag);
        attach(v, p->ch[flag ^ 1], flag ^ 1);
        p->ch[flag ^ 1] = NULL;
        v->fa = NULL;
        push_up(p);push_up(v);
        return root = v;
    }
    bool erase(int x) //刪除值為x的節點
    {
        node* p = search(x);
        if(!p || p->val != x)
            return false;
        if(p->recy > 1)
            p->recy--, p->size--;
        else if(!p->ch[0] && !p->ch[1])
            root = NULL;
        else if(!p->ch[0])
        {
            root = p->ch[1];
            root->fa = NULL;
            delete p;
        }
        else if(!p->ch[1])
        {
            root = p->ch[0];
            root->fa = NULL;
            delete p;
        }
        else
        {
            node* tmp = root->ch[0];
            tmp->fa = NULL;
            root->ch[0] = NULL;
            root = root->ch[1];
            root->fa = NULL;
            search(p->val);
            root->ch[0] = tmp;
            tmp->fa = root;
            delete p;
        }
        if(root) push_up(root);
        return true;
    }
    int rank(int x)//返回x的排名 從1開始,重複按第一個算
    {
        search(x);
        int res = root->ch[0] ? root->ch[0]->size + 1: 1;
        return res + (root->val < x ? root->recy : 0);
    }
    int find_rank(int x)//查詢排名x的值
    {
        if(x <= 0) return -inf;
        node* p = root;
        while(p)
        {
            if(p->ch[0] && p->ch[0]->size >= x)
                p = p->ch[0];
            else if((p->ch[0] ? p->ch[0]->size : 0) + p->recy >= x)
            {
                splay(p);
                return p->val;
            }
            else
            {
                x -= (p->ch[0] ? p->ch[0]->size : 0) + p->recy;
                p = p->ch[1];
            }
        }
        return inf;
    }
} Splay;
int a[200005];
int main()
{
    ios_base::sync_with_stdio(0);
    cin.tie(0); cout.tie(0);

    Splay.init();
    int n, m, x = 1, val;
    cin >> n >> m;
    for(int i = 1; i <= n; i++) cin >> a[i];
    for(int i = 1; i <= m; i++) {
        cin >> val;
        while(x <= val)
            Splay.insert(a[x++]);
        cout << Splay.find_rank(i) << endl;
    }
    return 0;
}