1. 程式人生 > >【線段樹】基本寫法,區間極值,區間延遲更新,多延遲標籤

【線段樹】基本寫法,區間極值,區間延遲更新,多延遲標籤

線段樹的特性:

開多大陣列來存線段樹?

線段樹表示的區間長度為k, 則葉節點個數k,線段樹作為一棵 完全二叉樹,它的總節點個數為n=2*k-1。  並且: 線段樹的左右子樹 葉節點個數相差不超過1.  假定線段樹的深度為h。則深度h對應的滿二叉樹節點個數為2^h-1. 下面證明: 2^h - 1 >= n > (2^(h-1) -1), 可以用反證法證明 首先,證明引理1:隨著線段樹的節點個數 2*k-1 增大,  深度不會出現反而減小的情況。 歸納法即可: k=1: h = 1, k= 2: h=2. .... :  然後,有 當線段樹節點個數 n= 2^(h-1)-1時,深度為h-1; 當 n= 2^h-1時,深度為h。 現在給定深度為h,則根據引理,必有 2^h - 1 >= n > (2^(h-1) -1) 那麼線段樹如果用陣列來存,陣列實際上存的是一棵滿二叉樹。那麼陣列的容量C=2^h-1<=2*n-1= 4*k-3
    於是,陣列存線段樹,它的容量為 4*k - 3, k 為區間容量。

更新: 

延遲更新(不一下子更新到葉節點,而是先更新到儘量少的幾個線段節點)

查詢:

  • 每次查詢到一個節點,它祖宗不能有延遲的更新未傳播下來。否則當前節點讀到的就是一個無效(未完全更新)資料了。
  • 訪問(查詢或更新)一個點之前,如果該點存在未下達的更新,就要先下達更新。(down)。 

         這樣每到一個點,路徑上面的祖宗點的更新都已經下達了下來。


struct node{
  int value;//統計的量
  int delta;//更新的量
}
struct tree{
  node nodes[MAXN];
  bool uTag;//
  void down(int ndx, int l, int r){
        node &d=nodes[ndx];
        if(!d.delta) return; //不需要向下更新
        if(l==r)  doSomething; //L==r是葉節點了
        nodes[2*ndx+1], nodes[2*n+2]   <=   d.delta //更新子節點
        d.delta=0; //不再更新
   }
   void set(int ndx, int l, int r, int delta, int s, int t){
        down(ndx, l, r)

        node &d=nodes[ndx];
        if(s<=l && t>=r)  	{ d.value+=delta, d.delta=delta; return;}

        //拆分[l, r]
        int mid=( l+r ) /2
        set(2*ndx+1, l, mid, delta, s, mid)  set(2*ndx+2, mid+1, r, delta, mid+1, t)
        
        //merge....
   }
   void get(int idx, int l, int r, int s, int t){
        down(idx, l, r)

        node &d=nodes[idx];
        if(s<=l && t>=r)  return xxx(d.value)

        //拆分[l, r]
        int mid=( l+r ) /2
        s1=get(2*idx+1, l, mid, d, s, mid)
        s2=get(2*idx+2, mid+1, r, d, mid+1, t)
        return xxx(s1, s2)
   }
};

例如進行

區間極值1

(關鍵在於更新的方向: 對於區間最小值這種, 先找到某個葉子節點,然後自底向上更新。而對於一次更新一個區間的求和性線段樹,則自頂向下更新。)

#include 
using namespace std;
#define MAXN 100000
int node[MAXN];
int value[MAXN];
int find(int l, int r, int idx, int s){//找到s對應葉節點的下標
    if(l==r) return idx;
    int mid=(l+r)/2;
    if(s>mid) return find(mid+1, r, idx*2+2, s);
    return find(l, mid, idx*2+1, s);
}
int set (int l, int r, int idx, int s, int t){//自底向上,更新區間中某個值
    value[s]=t;
    int n=find(l, r, idx, s);
    node[n]=s;//@error: this is initial value of all leaf node
    //update from bottom to top
    while(n>0){
        int p=(n-1)/2, q=4*p+3-n;
        node[p]=(value[node[q]]=r)){
        int mid=(l+r)/2;
        if(a<=mid && b>mid){
            int l1=get(l, mid, idx*2+1, a, mid);
            int l2=get(mid+1, r, idx*2+2, mid+1, b);
            return min(l1, l2);
        } 
        else{
            if(a>mid) l=mid+1, idx=idx*2+2;
            else r=mid, idx=idx*2+1;
        }
    }
    //cout<>n;
    for(int i=0; i>s;
        set(0, n-1, 0, i, s);
    }
    int m; cin>>m;
    for(int i=0; i>l>>a>>b; 
        //a--, b--;//@error
        if(l==0){
            a--,b--;
            cout<
另一個模板:

區間極值2

每個測試點(輸入檔案)有且僅有一組測試資料。

每組測試資料的第1行為一個整數N,意義如前文所述。

每組測試資料的第2行為N個整數,分別描述每種商品的重量,其中第i個整數表示標號為i的商品的重量weight_i。

每組測試資料的第3行為一個整數Q,表示小Hi總共詢問的次數與商品的重量被更改的次數之和。

每組測試資料的第N+4~N+Q+3行,每行分別描述一次操作,每行的開頭均為一個屬於0或1的數字,分別表示該行描述一個詢問和描述一次商品的重量的更改兩種情況。對於第N+i+3行,如果該行描述一個詢問,則接下來為兩個整數Li, Ri,表示小Hi詢問的一個區間[Li, Ri];如果該行描述一次商品的重量的更改,則接下來為兩個整數Pi,Wi,表示位置編號為Pi的商品的重量變更為Wi

對於100%的資料,滿足N<=10^6,Q<=10^6, 1<=Li<=Ri<=N,1<=Pi<=N, 0<weight_i, Wi<=10^4。

#include <iostream>
#include <stdio.h>
using namespace std;
int tree[3000000];
int data[1000000];
void init(int l, int r, int idx){
    if(l==r) tree[idx]=data[l];
    else{
        int mid=(l+r)/2;
        init(l, mid, idx*2+1);
        init(mid+1, r, idx*2+2);
        tree[idx]=min(tree[idx*2+1], tree[idx*2+2]);
    }
}
int get(int l, int r, int idx, int a, int b){
    if(a<=l && b>=r) return tree[idx];
    int mid=(l+r)/2;
    if(a<=mid && b>mid){
        int s = get(l, mid, idx*2+1, a, mid);
        int t = get(mid+1, r, idx*2+2, mid+1, b);
        return min(s, t);
    }
    else if(a>mid){
        return get(mid+1, r, idx*2+2, a, b);
    }
    else return get(l, mid, idx*2+1, a, b);
}
// 更新單點
int set(int l, int r, int idx, int a, int b){
    while(l!=r){
        int mid = (l+r)/2;
        if( a<=mid ) r=mid, idx=idx*2+1;
        else l=mid+1, idx=idx*2+2;
    }
    tree[idx] = b; ///@error lost this. update leaf node 
    while( idx>0 ){
        int p = (idx-1)/2, s = 4*p + 3 - idx; //@error: sb le, s = 2*p + 3 - idx ???
        tree[p]=min(tree[idx], tree[s]);
        idx = p;
    }
}
int main(){
    int n; scanf("%d", &n);
    for(int i=0; i<n; i++) scanf("%d", &data[i]);
    init(0, n-1, 0);
    int m; scanf("%d", &m);
    int c, a, b;
    for(int i=0; i<m; i++) {///@error: i<n
        scanf("%d %d %d", &c, &a, &b);
        if(c==0)  printf("%d\n", get(0, n-1, 0, a-1, b-1));
        else set(0, n-1, 0, a-1, b);
    }
    return 0;
}

區間更新問題用上文所述, 

延遲更新(對線段的更新,只更新到對應的節點),訪問先下達更新( 每讀寫一個節點前,先更新下達,從而保證訪問任何節點時祖宗節點沒有未下達的更新)

#include <iostream>
#include <vector>
#include <map>
#include <stdio.h>
using namespace std;
typedef pair<int, int> Pair;
Pair tree[800000];
int init(int l, int r, int idx, const vector<int>& a){
    tree[idx].second=0;
    if(l==r) return tree[idx].first=a[l];
    
    int mid=(l+r)/2;
    return tree[idx].first= init(l, mid, idx*2+1, a) + init(mid+1, r, idx*2+2, a);
}
void down(int l, int r, int idx){
    int d = tree[idx].second;
    if(d && l<r){
        int mid = (l+r)/2;
        tree[idx*2+1].first=d*(mid-l+1),tree[idx*2+1].second=d;
        tree[idx*2+2].first=d*(r-mid),  tree[idx*2+2].second=d;
        tree[idx].second=0;
    }
}
void set(int l, int r, int idx, int s, int t, int delta){
    down(l, r, idx);
    
    if(s<=l && t>=r) {
        tree[idx].first = delta*(r-l+1), 
        tree[idx].second= delta;
        return;
    }
  
    int mid = (l + r)/2;
    if(t<=mid) set(l, mid, idx*2+1, s, t, delta);
    else if(s>mid) set(mid+1, r, idx*2+2, s, t, delta);
    else{
        set(l, mid, idx*2+1, s, mid, delta);
        set(mid+1, r, idx*2+2, mid+1, t, delta);
    }
    
    // back
    tree[idx].first = tree[idx*2+1].first + tree[idx*2+2].first;
}
int get(int l, int r, int idx, int s, int t){
    down(l, r, idx);
    
    if(s<=l && t>=r) return tree[idx].first;
    
    int mid = ( l + r )/2;
    if(s>mid)  return get(mid+1, r, idx*2+2, s, t);
    else if(t<=mid) return get(l, mid, idx*2+1, s, t);
    else{
        return get(l, mid, idx*2+1, s, mid)+
                get(mid+1, r, idx*2+2, mid+1, t);
    }
}


int main(){
    int n; cin>>n;
    vector<int> a(n, 0);
    for(int i=0; i<n; i++){ cin>>a[i]; }
    init(0, n-1, 0, a);
    int m; cin>>m;
    char buf[1000];
    for(int i=0; i<m; i++){
        int op, a, b, c;
        cin>>op;
        if(op==0) {
            cin>>a>>b;
            cout<<get(0, n-1, 0, a-1, b-1)<<endl;
        }
        else {
            cin>>a>>b>>c;
            set(0, n-1, 0, a-1, b-1, c);
        }
    }
    return 0;
}

多延遲標籤的線段樹

延遲讀寫的正確性保證

線段樹的set/get/add等讀寫操作,都是從根節點一直向下的順序訪問;而且每訪問到一個節點,會及時把節點上的懶惰標籤向下更新

這樣做的目的是  保證每訪問到一個節點的時候,它的祖宗的更新都已經傳遞下去,即祖宗節點中不再存在延遲標籤(從而保證該節點上的各項值的正確性),這是很關鍵的。

延遲標籤的時間先後性

而線段樹的延遲存在這個規律:

祖宗節點和子節點同時存在延遲標籤的時候,以祖宗節點為準,祖宗節點的延遲更新會覆蓋子節點的

不難看出祖宗結點上的延遲標籤 來自於 後來的更新,而子節點上的延遲是來自早期的還沒來得及再往下傳遞的更新。後來的更新肯定是覆蓋之前的更新的。

多個延遲標籤的問題

線段樹上如果有多個懶惰標籤會出現什麼狀況?比如set標籤和add標籤同時出現?

誰能覆蓋誰:

一般後來的set會覆蓋早期的add。而後來的add則不會覆蓋早期的set。

舉例

描述

小Hi和小Ho都是遊戲迷,“模擬都市”是他們非常喜歡的一個遊戲,在這個遊戲裡面他們可以化身上帝模式,買賣房產。

在這個遊戲裡,會不斷的發生如下兩種事件:一種是房屋自發的漲價或者降價,而另一種是政府有關部門針對房價的硬性調控。房價的變化自然影響到小Hi和小Ho的決策,所以他們希望能夠知道任意時刻某個街道中所有房屋的房價總和是多少——但是很不幸的,遊戲本身並不提供這樣的計算。不過這難不倒小Hi和小Ho,他們將這個問題抽象了一下,成為了這樣的問題:

小Hi和小Ho所關注的街道的長度為N米,從一端開始每隔1米就有一棟房屋,依次編號為0..N,在遊戲的最開始,每棟房屋都有一個初始價格,其中編號為i的房屋的初始價格為p_i,之後共計發生了M次事件,所有的事件都是對於編號連續的一些房屋發生的,其中第i次事件如果是房屋自發的漲價或者降價,則被描述為三元組(L_i, R_i, D_i),表示編號在[L_i, R_i]範圍內的房屋的價格的增量(即正數為漲價,負數為降價)為D_i;如果是政府有關部門針對房價的硬性調控,則被描述為三元組(L_i, R_i, V_i),表示編號在[L_i, R_i]範圍內的房屋的價格全部變為V_i。而小Hi和小Ho希望知道的是——每次事件發生之後,這個街道中所有房屋的房價總和是多少。

下面是一個多標籤的實現(還有其他方法)


程式碼

#include <iostream>
#include <vector>
using namespace std;

struct node{
    int sum;
    int add;// 0 or not
    int set;//-1 or not 
};
node tree[4000000];
void down(int l, int r, int idx){
    if(tree[idx].set!=-1){
        int tmp =  tree[idx*2+1].set = tree[idx*2+2].set = tree[idx].set + tree[idx].add;//@error:
        tree[idx].add= tree[idx*2+1].add = tree[idx*2+2].add = 0;
        tree[idx].sum = (r-l+1)*tmp;
        tree[idx].set = -1;
    }    
    else if(tree[idx].add){
        tree[idx].sum += (r-l+1)*tree[idx].add;
        tree[idx*2+1].add += tree[idx].add;    
        tree[idx*2+2].add += tree[idx].add;    
        tree[idx].add = 0;
    }
}
void set(int l, int r, int idx, int a, int b, int d){
    if(a<=l && b>=r){
        tree[idx].set = d, tree[idx].add = 0;     
        return;    
    }
    down(l, r, idx);
    int mid = (l+r)/2;
    if(a<=mid) set(l, mid, idx*2+1, a, min(mid, b), d);
    if(b>mid) set(mid+1, r, idx*2+2, max(a, mid+1), b, d);
    down(l, mid, idx*2+1);
    down(mid+1, r, idx*2+2);
    tree[idx].sum = tree[idx*2+1].sum + tree[idx*2+2].sum;        
    //cout<<"set:"<<l<<","<<r<<":"<<tree[idx].sum<<endl;
}
void add(int l, int r, int idx, int a, int b, int d){
    if(a<=l && b>=r){
        tree[idx].add += d;        
        return;    
    }
    down(l, r, idx);
    int mid = (l+r)/2;
    if(a<=mid) add(l, mid, idx*2+1, a, min(mid, b), d);
    if(b>mid) add(mid+1, r, idx*2+2, max(a, mid+1), b, d);
    down(l, mid, idx*2+1);
    down(mid+1, r, idx*2+2);//@error: r not mid
    tree[idx].sum = tree[idx*2+1].sum + tree[idx*2+2].sum;        
    //cout<<"add:"<<l<<","<<r<<":"<<tree[idx].sum<<endl;
}
void init(int l, int r, int idx, vector<int> &p){
    if(l==r) {
        tree[idx].sum = p[l];
        tree[idx].set = -1, tree[idx].add = 0;
        return;
    }
    int mid = (l+r)/2;
    init(l, mid, idx*2+1, p);
    init(mid+1, r, idx*2+2, p);
    tree[idx].sum = tree[idx*2+1].sum + tree[idx*2+2].sum;
    tree[idx].set = -1, tree[idx].add = 0;
}
int get(int l, int r, int idx, int a, int b){
    down(l, r, idx);
    if(a<=l && b>=r){
        return tree[idx].sum;
    }
    int mid = (l+r)/2;
    int x = 0;
    if(a<=mid) x += get(l, mid, idx*2+1, a, min(mid, b));
    if(b>mid) x += get(mid+1, r, idx*2+2, max(a, mid+1), b);
    return x;
}
int main(){
    int n, m; cin>>n>>m; n++;
    vector<int> price(n);
    for(int i=0; i<n; i++) cin>>price[i];
    init(0, n-1, 0, price);

    for(int i=0; i<m; i++) {
        int o, a, b, d;
        cin>>o>>a>>b>>d;
        //cout<<o<<" "<<a<<" "<<b<<" "<<d<<endl;
        if(o==0) add(0, n-1, 0, a, b, d); 
        else set(0, n-1, 0, a, b, d);
        //for(int i =0; i<n; i++) cout<<get(0, n-1, 0, i, i)<<" ";cout<<endl;
        cout<<get(0, n-1, 0, 0, n-1)<<endl;
    }
    return 0;
}