1. 程式人生 > 實用技巧 >線段樹(簡單版)

線段樹(簡單版)

使用線段樹維護一段長度為\(n\)的區間,線段樹的空間要開\(4n\)的原因

設要維護的區間為\([1, n]\)長度為\(n\)

\(n = 2^k, k = 0, 1, 2, 3, ...\)

那麼構造的線段樹的結點個數不多不少剛好用\(2^{k + 1} - 1 < 2n\)個, 所以開\(2n\)足夠,此時的線段樹為深度為\(k + 1\)滿二叉樹。

\(n \neq 2^k\)那麼構造這個線段樹所用的結點數(包括浪費的)和一個構造區間長度為\(2^t\)的區間的線段樹所用的結點數相同其中\(2^t\)為大於n的最小的\(2\)的次冪。

下面來求 \(t\)

因為\(n < 2^t\)

那麼\(log_{2}n < t\)

所以\(t\)是大於\(log_{2}n\)的最小整數

所以\(t = \lfloor log_{2}n \rfloor + 1\)

由於構造長度為\(2^t\)的區間的線段樹需要\(2^{t + 1} - 1\)個結點,即\(2^{ \lfloor log_{2}n \rfloor + 1 + 1} - 1\)個結點, 即\(4 * 2^{ \lfloor log_{2}n \rfloor} - 1 \lt 4n\),所以為了保證不越界,開\(4n\)空間。

通過浪費一些空間,來讓線段樹具有完全二叉樹的性質(左孩子,右孩子,雙親)\(n \neq 2^k\)

時,畫出線段樹,可以明顯看出它不是完全二叉樹。

線段樹的層數問題:

\(n = 2^k\),最後一層結點數為\(n\),層數為\(log_{2}n + 1\)

\(n \ne 2^k\),最後一層結點數為\(2^t\)(包括空著的),層數為\(\lfloor log_{2}n \rfloor + 1 + 1\)

所以線段樹的層數為\(logn\)級別。

線段樹的核心操作:

  1. query(root, l, r)
  2. modify(root, idx, val)
  3. build(root, l, r) // 在l,r區間上建立線段樹
  4. pushup(root) //更新root結點維護的資訊(max, min, ...)
  5. pushdown(帶lazy標記的線段樹)

簡單版線段樹

  1. 單點修改
  2. 區間查詢(sum,max,min...)

原理:用線段樹的結點來維護每一段區間,單點修改(遞迴修改,修改所有相關的區間),區間查詢(遞迴查詢返回sum),兩者的複雜度\(O(logn)\)

模板題

給定 n 個數組成的一個數列,規定有兩種操作,一是修改某個元素,二是求子數列[a, b]的連續和。

輸入格式

第一行包含兩個整數 n 和 m,分別表示數的個數和操作次數。

第二行包含 n 個整數,表示完整數列。

接下來 m 行,每行包含三個整數 k,a,b (k=0,表示求子數列[a,b]的和;k=1,表示第 a 個數加 b)。

數列從1開始計數。

#include<iostream>
using namespace std;

const int N = 100010;

int a[N];
int n, m;

struct Node{
    int l, r;
    int sum;
}tr[N * 4];


void pushup(int u){
    tr[u].sum = tr[u << 1].sum + tr[u << 1 | 1].sum;
}

void build(int u, int l, int r){
    if(l == r) tr[u] = {l, r, a[l]};
    else{
        tr[u] = {l, r};
        int mid = l + r >> 1;
        build(u << 1, l, mid);
        build(u << 1 | 1, mid + 1, r);
        pushup(u);
    }
}

int modify(int u, int x, int v){
    if(tr[u].l == tr[u].r) tr[u].sum += v;
    else{
        int mid = tr[u].l + tr[u].r >> 1;
        if(x <= mid) modify(u << 1, x, v);
        else modify(u << 1 | 1, x, v);
        pushup(u);
    }
}

int query(int u, int l, int r){
    if(tr[u].l >= l && tr[u].r <= r) return tr[u].sum;
    int mid = tr[u].l + tr[u].r >> 1;
    int sum = 0;
    if(l <= mid) sum = query(u << 1, l, r);
    if(r > mid) sum += query(u << 1 | 1, l, r);
    
    return sum;
}

int main(){
    cin >> n >> m;
    
    for(int i = 1; i <= n; i ++) cin >> a[i];
    
    build(1, 1, n);
    
    while(m --){
        int k, a, b;
        cin >> k >> a >> b;
        
        if(k) modify(1, a, b);
        else cout << query(1, a, b) << endl;
    }
    
    return 0;
}