1. 程式人生 > >線段樹 資料結構詳解與模板

線段樹 資料結構詳解與模板

線段樹是一個查詢和修改複雜度都為log(n)的資料結構。主要用於陣列的單點修改&&單點查詢&&區間求和&&區間修改.

另外一個擁有類似功能的是樹狀陣列但是樹狀陣列最常用的是單點修改&&區間求和.

線段樹完全涵蓋樹狀陣列所有功能

具體區別和聯絡如下:

1.兩者在複雜度上同級, 但是樹狀陣列的常數明顯優於線段樹, 其程式設計複雜度也遠小於線段樹.

2.樹狀陣列的作用被線段樹完全涵蓋, 凡是可以使用樹狀陣列解決的問題, 使用線段樹一定可以解決, 但是線段樹能夠解決的問題樹狀陣列未必能夠解決.

說了這麼多,其實線段樹就是個

二叉樹而已,只不過葉子節點記錄的是區間之間的和而已

先給一份樣圖

其中,矩形內的是區間之和,區間外的是陣列下標(線段樹用陣列存資料).不難看出,線段樹的左孩子=根節點下標*2,右孩子=根節點下標*2+1,而左右孩子則是根節點將區間二分的結果.

先給出線段樹的結構體定義然後咱們再仔細講講各種(sao)操作

struct node {
    int l,r,w,flag;
} a[maxn<<2]; //4倍空間


結構體裡有個延遲標記的東西,咱們下面再說這個問題

需要注意的是如果是n個數,那麼線段樹需要開4n的空間.理論上是2n-1的空間,但是你遞迴建立的時候當前節點為r,那麼左右孩子分別是2*r,2*r+1,此時編譯器並不知道遞迴已結束,因為你的結束條件是在遞迴之前的,所以編譯器會認為下標訪問出錯,也就是空間開小了,應該再開大2倍。有時候可能你發現開2,3倍的空間也可以AC,那只是因為測試資料並沒有那麼大。

至於為什麼開4倍,我從網上摘抄了一部分(反正我是看不懂

            首先線段樹是一棵二叉樹,最底層有n個葉子節點(n為區間大小)

            那麼由此可知,此二叉樹的高度為,可證

        然後通過等比數列求和求得二叉樹的節點個數,具體公式為,(x為樹的層數,為樹的高度+1)

            化簡可得,整理之後即為(近似計算忽略掉-1)

             證畢

線段樹的基礎操作主要有5個:

建樹、單點查詢、單點修改、區間查詢、區間修改。

----------------------------------------------------------------------------------

建樹:會建二叉樹的話這一條也就沒什麼說的了

主要就是遞迴建樹而已

其中,k為根節點,l,r分別為左右區間

輸入n個數將其建立為線段樹只需要呼叫

build(1,1,n)即可

遞迴過程應該都能看懂(看不懂回去學二叉樹去

void build(int k,int l,int r) {
    a[k].l = l,a[k].r = r;
    if(a[k].l == a[k].r) {
        scanf("%d",&a[k].w);
        //cin >> a[k].w;
        return;
    }
    build(k*2,l,(l+r)/2);    //左
    build(k*2+1,(l+r)/2+1, r);//右
    a[k].w += a[k*2].w+a[k*2+1].w;//求和
}

---------------------------------------------------------------------------------------------------------------

延遲標記

這裡咱們開始用到上面的變數flag了

上面說了,線段樹是支援區間修改的,比如說開始那張圖,咱把[1,5]都加上3,總不能把[1,5],[1,3],[4,5],[1,2],[3,3],[4,4],[5,5],[1,1],[2,2]都修改了啊,這樣從第二層一直到第四層那我還要這個線段樹幹嘛,時間早爆炸了.

這時候,精髓部分來了,誒咱就只修改a[2]這個地方,也就是[1,5],下面的暫時用不上,就不管它.然後讓flag=3.

如果下一次需要用到這一部分資料的話,將flag下傳,這樣查詢哪一部分咱就算哪一部分的和,其他的就不管

                    要將[1,5]這部分+3但是不查詢他的話,那麼[1,5]的左右孩子也就沒有更改的必要了

這個flag就是延遲標記,有了它,我們就只需要將修改過的區域標記,等到查詢此部分的時候再向下修改就行了

以線段樹區間1-10,初值全為0,[1,5]全部+3為例:

可以看出,[1,5]的子區間內的區間和是不對的(修改後不應該為0~)

沒關係,我們只需要修改[1,5]和包含[1,5]的區間的內容即可,然後我們讓flag = 3,[1,5]的子區間暫時不用管

(黑色數字代表區間和,紅色代表flag的值)

如果接下來查詢[1,3]或者[1,5]的其他子區間,我們再向下計算區間和,對於查詢[1,3]而言,圖是這樣子的:

結論已經呼之欲出了:

如果查詢的區域有延遲標記flag,就將標記下傳,並且左右孩子的和+=flag*(左右孩子區間內所存的數)

比如說[1,5]的左孩子區間為1-3,則為3*(3-1+1) = 3*3

具體操作如下

void down(int k) {
    a[k*2].flag += a[k].flag;            //標記下傳
    a[k*2+1].flag += a[k].flag;

    a[k*2].w += a[k].flag*(a[k*2].r-a[k*2].l+1);    //標記求和
    a[k*2+1].w += a[k].flag *(a[k*2+1].r-a[k*2+1].l+1);
    a[k].flag = 0;                        //下傳之後清空當前節點的標記
}

---------------------------------------------------------------------------------------------------------------

區間查詢

有了延遲標記的基礎我們就可以進行區間求和了

也是比較簡單的過程,會二分應該就能看懂

void askinterval(int k,int x,int y) {
    if(a[k].l>=x && a[k].r<=y) {
        ans += a[k].w;            ///ans為全域性變數,記得每次查詢令ans = 0;
        return;
    }
    if(a[k].flag)
        down(k);
    int buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        askinterval(k*2,x,y);           ///遞迴查左子樹
    if(y > buf)
        askinterval(k*2+1,x,y);         ///遞迴查右子樹
}

-----------------------------------------------------------------------------------------------------------------

區間修改

區間修改和上面的區間查詢程式碼基本相同,自行研究咯~

void changeinterval(int k,int x,int y,int z) {
    if(a[k].l>=x &&a[k].r<=y) {
        a[k].w += (a[k].r-a[k].l+1)*z;
        a[k].flag += z;
        return;
    }
    if(a[k].flag)
        down(k);
    int buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        changeinterval(k*2,x,y,z);
    if(y > buf)
        changeinterval(k*2+1,x,y,z);
    a[k].w = a[k*2].w + a[k*2+1].w;
}

-----------------------------------------------------------------------------------------------------------------

單點查詢

其實單點查詢完全可以使用上面區間查詢的函式,反正都是一樣的~

不過畢竟是模板嘛,還是貼一份程式碼

void askinterval(int k,int x) {
    if(a[k].l==x && a[k].r==x) {
        ans = a[k].w;
        return;
    }
    if(a[k].flag)
        down(k);
    int buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        askinterval(k*2,x);
    if(y > buf)
        askinterval(k*2+1,x);
}

單點修改

同樣,單點修改也可以使用區間修改的程式碼,只需要讓x和y一樣就行.

void changeinterval(int k,int x,int z) {
    if(a[k].l==x &&a[k].r==x) {
        a[k].w += (a[k].r-a[k].l+1)*z;
        a[k].flag += z;
        return;
    }
    if(a[k].flag)
        down(k);
    int buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        changeinterval(k*2,x,z);
    if(y > buf)
        changeinterval(k*2+1,x,z);
    a[k].w = a[k*2].w + a[k*2+1].w;
}

老規矩,最後一道例題

解題程式碼如下:

#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <string>
#include <vector>
#define For(a,b) for(ll a=0;a<b;a++)
#define mem(a,b) memset(a,b,sizeof(a))
#define _mem(a,b) memset(a,0,(b+1)<<2)
#define lowbit(a) ((a)&-(a))
#define IO do{\
    ios::sync_with_stdio(false);\
    cin.tie(0);\
    cout.tie(0);}while(0)

using namespace std;
typedef long long ll;
const ll maxn =  2*1e5+5;
const ll INF = 0x3f3f3f3f;
struct node {
    ll l,r,w,flag;
} a[maxn<<2]; //4±¶Êý×é
ll c[maxn];
ll cnt;
void build(ll k,ll l,ll r) {
    a[k].l = l,a[k].r = r;
    if(a[k].l == a[k].r) {
        scanf("%lld",&a[k].w);
        //cin >> a[k].w;
        return;
    }
    build(k*2, l, (l+r)/2);
    build(k*2+1, (l+r)/2+1, r);
    a[k].w = max(a[k*2].w,a[k*2+1].w);
}

void changellerval(ll k,ll x,ll z) {
    if(a[k].l==x &&a[k].r==x) {
        a[k].w = z;
        return;
    }
    ll buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        changellerval(k*2,x,z);
    if(x > buf)
        changellerval(k*2+1,x,z);
    a[k].w = max(a[k*2].w, a[k*2+1].w);
}
ll ans;
void askllerval(ll k,ll x,ll y) {
    if(a[k].l>=x && a[k].r<=y) {
        ans = max(a[k].w,ans);
        return;
    }
    ll buf = (a[k].l+a[k].r)/2;
    if(x <= buf)
        askllerval(k*2,x,y);
    if(y > buf)
        askllerval(k*2+1,x,y);
}

int main() {
    //IO;

    char buf;
    ll n,m;
    ll x,y,z;
    while(cin >> n >> m) {
        build(1,1,n);
        For(i,m) {
            getchar();
            scanf("%c",&buf);
            //cin >> buf;
            if(buf == 'Q') {
                scanf("%lld%lld",&x,&y);
                //cin >> x >> y;
                ans = 0;
                askllerval(1,x,y);
                printf("%lld\n",ans);
                //cout << ans << endl;
            } else {
                scanf("%lld%lld",&x,&z);
                //cin >> x >> y >> z;
                changellerval(1,x,z);
            }
        }
    }
    return 0;
}

模板如下:

#include <map>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define length (a[i].r-a[i].l+1)
#define fori(a) for(int i=0;i<a;i++)
#define forj(a) for(int j=0;j<a;j++)
#define ifor(a) for(int i=1;i<=a;i++)
#define jfor(a) for(int j=1;j<=a;j++)
#define mem(a,b) memset(a,b,sizeof(a))
#define IN freopen("in.txt","r",stdin)
#define OUT freopen("out.txt","w",stdout)
#define _mem(a,b) memset(a,0,b<<2)
#define IO do{\
    ios::sync_with_stdio(false);\
    cin.tie(0);\
    cout.tie(0);}while(0)

using namespace std;
typedef long long ll;
const int maxn =  2*1e5+10;
const int INF = 0x3f3f3f3f;
const int inf = 0x3f;
const double EPS = 1e-7;
const double Pi = acos(-1);
const int MOD = 1e9+7;
struct Node
{
    int l,r,w,flag;
    Node() {};
    Node(int _l,int _r,int _v,int _f){flag=_f,l=_l,r=_r,w=_v;}
    int mid(){return (l+r)/2;}
};
Node a[maxn<<2];
void build(int k,int l,int r) {            //建樹
    a[k] = Node(l,r,0,0);
    if(a[k].l == a[k].r){
        cin >> a[k].w;
        return;
    }
    build(k<<1,l,(l+r)/2);
    build(k<<1|1,(l+r)/2+1,r);
    a[k].w = a[k<<1].w+a[k<<1|1].w;
}

void down(int k) {                        //延遲標記下傳
    a[k<<1].flag += a[k].flag;
    a[k<<1|1].flag += a[k].flag;
    a[k<<1].w += a[k].flag*(a[k<<1].r-a[k<<1].l+1);
    a[k<<1|1].w += a[k].flag *(a[k<<1|1].r-a[k<<1|1].l+1);
    a[k].flag = 0;
}
int res;

void update(int k,int x,int y,int z) {        //區間更新
    if(a[k].l>=x &&a[k].r<=y) {
        a[k].w += z*(a[k].r-a[k].l+1);
        a[k].flag += z;
        return;
    }
    if(a[k].flag)
        down(k);
    if(x <= a[k].mid())
        update(k<<1,x,y,z);
    if(y > a[k].mid())
        update(k<<1|1,x,y,z);
    a[k].w = a[k<<1].w+a[k<<1|1].w;
}
int ans;
void query(int k,int x,int y) {            //區間查詢
    if(a[k].l>=x && a[k].r<=y) {
        res += a[k].w;
        return;
    }
    if(a[k].flag)
        down(k);
    if(x <= a[k].mid())
        query(k<<1,x,y);
    if(y > a[k].mid())
        query(k<<1|1,x,y);
    a[k].w = a[k<<1].w+a[k<<1|1].w;
}

int main()
{
    IO;
    //IN;
    int t ;
    string s;
    int n,k;
    int q;
    int x,y,z;
    cin >> n >> k;
    build(1,1,n);
    while(k--){
        cin >> s;
        cin >> x >> y;
        if(s[0] == 'Q'){
            res = 0;
            query(1,x,y);
            cout << res << endl;
        }
        else
        {
            cin >> z ;
            update(1,x,y,z);
        }
    }
    return 0;
}