線段樹 資料結構詳解與模板
線段樹是一個查詢和修改複雜度都為log(n)的資料結構。主要用於陣列的單點修改&&單點查詢&&區間求和&&區間修改.
另外一個擁有類似功能的是樹狀陣列,但是樹狀陣列最常用的是單點修改&&區間求和.
線段樹完全涵蓋樹狀陣列所有功能
具體區別和聯絡如下:
1.兩者在複雜度上同級, 但是樹狀陣列的常數明顯優於線段樹, 其程式設計複雜度也遠小於線段樹.
2.樹狀陣列的作用被線段樹完全涵蓋, 凡是可以使用樹狀陣列解決的問題, 使用線段樹一定可以解決, 但是線段樹能夠解決的問題樹狀陣列未必能夠解決.
說了這麼多,其實線段樹就是個
先給一份樣圖
其中,矩形內的是區間之和,區間外的是陣列下標(線段樹用陣列存資料).不難看出,線段樹的左孩子=根節點下標*2,右孩子=根節點下標*2+1,而左右孩子則是根節點將區間二分的結果.
先給出線段樹的結構體定義然後咱們再仔細講講各種(sao)操作
struct node {
int l,r,w,flag;
} a[maxn<<2]; //4倍空間
結構體裡有個延遲標記的東西,咱們下面再說這個問題
需要注意的是如果是n個數,那麼線段樹需要開4n的空間.理論上是2n-1的空間,但是你遞迴建立的時候當前節點為r,那麼左右孩子分別是2*r,2*r+1,此時編譯器並不知道遞迴已結束,因為你的結束條件是在遞迴之前的,所以編譯器會認為下標訪問出錯,也就是空間開小了,應該再開大2倍。有時候可能你發現開2,3倍的空間也可以AC,那只是因為測試資料並沒有那麼大。
至於為什麼開4倍,我從網上摘抄了一部分(反正我是看不懂
首先線段樹是一棵二叉樹,最底層有n個葉子節點(n為區間大小)
那麼由此可知,此二叉樹的高度為,可證
然後通過等比數列求和求得二叉樹的節點個數,具體公式為,(x為樹的層數,為樹的高度+1)
化簡可得,整理之後即為(近似計算忽略掉-1)
證畢
線段樹的基礎操作主要有5個:
建樹、單點查詢、單點修改、區間查詢、區間修改。
----------------------------------------------------------------------------------
建樹:會建二叉樹的話這一條也就沒什麼說的了
主要就是遞迴建樹而已
其中,k為根節點,l,r分別為左右區間
輸入n個數將其建立為線段樹只需要呼叫
build(1,1,n)即可
遞迴過程應該都能看懂(看不懂回去學二叉樹去
void build(int k,int l,int r) {
a[k].l = l,a[k].r = r;
if(a[k].l == a[k].r) {
scanf("%d",&a[k].w);
//cin >> a[k].w;
return;
}
build(k*2,l,(l+r)/2); //左
build(k*2+1,(l+r)/2+1, r);//右
a[k].w += a[k*2].w+a[k*2+1].w;//求和
}
---------------------------------------------------------------------------------------------------------------
延遲標記
這裡咱們開始用到上面的變數flag了
上面說了,線段樹是支援區間修改的,比如說開始那張圖,咱把[1,5]都加上3,總不能把[1,5],[1,3],[4,5],[1,2],[3,3],[4,4],[5,5],[1,1],[2,2]都修改了啊,這樣從第二層一直到第四層那我還要這個線段樹幹嘛,時間早爆炸了.
這時候,精髓部分來了,誒咱就只修改a[2]這個地方,也就是[1,5],下面的暫時用不上,就不管它.然後讓flag=3.
如果下一次需要用到這一部分資料的話,將flag下傳,這樣查詢哪一部分咱就算哪一部分的和,其他的就不管
要將[1,5]這部分+3但是不查詢他的話,那麼[1,5]的左右孩子也就沒有更改的必要了
這個flag就是延遲標記,有了它,我們就只需要將修改過的區域標記,等到查詢此部分的時候再向下修改就行了
以線段樹區間1-10,初值全為0,[1,5]全部+3為例:
可以看出,[1,5]的子區間內的區間和是不對的(修改後不應該為0~)
沒關係,我們只需要修改[1,5]和包含[1,5]的區間的內容即可,然後我們讓flag = 3,[1,5]的子區間暫時不用管
(黑色數字代表區間和,紅色代表flag的值)
如果接下來查詢[1,3]或者[1,5]的其他子區間,我們再向下計算區間和,對於查詢[1,3]而言,圖是這樣子的:
結論已經呼之欲出了:
如果查詢的區域有延遲標記flag,就將標記下傳,並且左右孩子的和+=flag*(左右孩子區間內所存的數)
比如說[1,5]的左孩子區間為1-3,則為3*(3-1+1) = 3*3
具體操作如下
void down(int k) {
a[k*2].flag += a[k].flag; //標記下傳
a[k*2+1].flag += a[k].flag;
a[k*2].w += a[k].flag*(a[k*2].r-a[k*2].l+1); //標記求和
a[k*2+1].w += a[k].flag *(a[k*2+1].r-a[k*2+1].l+1);
a[k].flag = 0; //下傳之後清空當前節點的標記
}
---------------------------------------------------------------------------------------------------------------
區間查詢
有了延遲標記的基礎我們就可以進行區間求和了
也是比較簡單的過程,會二分應該就能看懂
void askinterval(int k,int x,int y) {
if(a[k].l>=x && a[k].r<=y) {
ans += a[k].w; ///ans為全域性變數,記得每次查詢令ans = 0;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askinterval(k*2,x,y); ///遞迴查左子樹
if(y > buf)
askinterval(k*2+1,x,y); ///遞迴查右子樹
}
-----------------------------------------------------------------------------------------------------------------
區間修改
區間修改和上面的區間查詢程式碼基本相同,自行研究咯~
void changeinterval(int k,int x,int y,int z) {
if(a[k].l>=x &&a[k].r<=y) {
a[k].w += (a[k].r-a[k].l+1)*z;
a[k].flag += z;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changeinterval(k*2,x,y,z);
if(y > buf)
changeinterval(k*2+1,x,y,z);
a[k].w = a[k*2].w + a[k*2+1].w;
}
-----------------------------------------------------------------------------------------------------------------
單點查詢
其實單點查詢完全可以使用上面區間查詢的函式,反正都是一樣的~
不過畢竟是模板嘛,還是貼一份程式碼
void askinterval(int k,int x) {
if(a[k].l==x && a[k].r==x) {
ans = a[k].w;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askinterval(k*2,x);
if(y > buf)
askinterval(k*2+1,x);
}
單點修改
同樣,單點修改也可以使用區間修改的程式碼,只需要讓x和y一樣就行.
void changeinterval(int k,int x,int z) {
if(a[k].l==x &&a[k].r==x) {
a[k].w += (a[k].r-a[k].l+1)*z;
a[k].flag += z;
return;
}
if(a[k].flag)
down(k);
int buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changeinterval(k*2,x,z);
if(y > buf)
changeinterval(k*2+1,x,z);
a[k].w = a[k*2].w + a[k*2+1].w;
}
老規矩,最後一道例題
解題程式碼如下:
#include <iostream>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <cmath>
#include <algorithm>
#include <queue>
#include <string>
#include <vector>
#define For(a,b) for(ll a=0;a<b;a++)
#define mem(a,b) memset(a,b,sizeof(a))
#define _mem(a,b) memset(a,0,(b+1)<<2)
#define lowbit(a) ((a)&-(a))
#define IO do{\
ios::sync_with_stdio(false);\
cin.tie(0);\
cout.tie(0);}while(0)
using namespace std;
typedef long long ll;
const ll maxn = 2*1e5+5;
const ll INF = 0x3f3f3f3f;
struct node {
ll l,r,w,flag;
} a[maxn<<2]; //4±¶Êý×é
ll c[maxn];
ll cnt;
void build(ll k,ll l,ll r) {
a[k].l = l,a[k].r = r;
if(a[k].l == a[k].r) {
scanf("%lld",&a[k].w);
//cin >> a[k].w;
return;
}
build(k*2, l, (l+r)/2);
build(k*2+1, (l+r)/2+1, r);
a[k].w = max(a[k*2].w,a[k*2+1].w);
}
void changellerval(ll k,ll x,ll z) {
if(a[k].l==x &&a[k].r==x) {
a[k].w = z;
return;
}
ll buf = (a[k].l+a[k].r)/2;
if(x <= buf)
changellerval(k*2,x,z);
if(x > buf)
changellerval(k*2+1,x,z);
a[k].w = max(a[k*2].w, a[k*2+1].w);
}
ll ans;
void askllerval(ll k,ll x,ll y) {
if(a[k].l>=x && a[k].r<=y) {
ans = max(a[k].w,ans);
return;
}
ll buf = (a[k].l+a[k].r)/2;
if(x <= buf)
askllerval(k*2,x,y);
if(y > buf)
askllerval(k*2+1,x,y);
}
int main() {
//IO;
char buf;
ll n,m;
ll x,y,z;
while(cin >> n >> m) {
build(1,1,n);
For(i,m) {
getchar();
scanf("%c",&buf);
//cin >> buf;
if(buf == 'Q') {
scanf("%lld%lld",&x,&y);
//cin >> x >> y;
ans = 0;
askllerval(1,x,y);
printf("%lld\n",ans);
//cout << ans << endl;
} else {
scanf("%lld%lld",&x,&z);
//cin >> x >> y >> z;
changellerval(1,x,z);
}
}
}
return 0;
}
模板如下:
#include <map>
#include <cmath>
#include <cstdio>
#include <string>
#include <cstring>
#include <iostream>
#include <algorithm>
#define length (a[i].r-a[i].l+1)
#define fori(a) for(int i=0;i<a;i++)
#define forj(a) for(int j=0;j<a;j++)
#define ifor(a) for(int i=1;i<=a;i++)
#define jfor(a) for(int j=1;j<=a;j++)
#define mem(a,b) memset(a,b,sizeof(a))
#define IN freopen("in.txt","r",stdin)
#define OUT freopen("out.txt","w",stdout)
#define _mem(a,b) memset(a,0,b<<2)
#define IO do{\
ios::sync_with_stdio(false);\
cin.tie(0);\
cout.tie(0);}while(0)
using namespace std;
typedef long long ll;
const int maxn = 2*1e5+10;
const int INF = 0x3f3f3f3f;
const int inf = 0x3f;
const double EPS = 1e-7;
const double Pi = acos(-1);
const int MOD = 1e9+7;
struct Node
{
int l,r,w,flag;
Node() {};
Node(int _l,int _r,int _v,int _f){flag=_f,l=_l,r=_r,w=_v;}
int mid(){return (l+r)/2;}
};
Node a[maxn<<2];
void build(int k,int l,int r) { //建樹
a[k] = Node(l,r,0,0);
if(a[k].l == a[k].r){
cin >> a[k].w;
return;
}
build(k<<1,l,(l+r)/2);
build(k<<1|1,(l+r)/2+1,r);
a[k].w = a[k<<1].w+a[k<<1|1].w;
}
void down(int k) { //延遲標記下傳
a[k<<1].flag += a[k].flag;
a[k<<1|1].flag += a[k].flag;
a[k<<1].w += a[k].flag*(a[k<<1].r-a[k<<1].l+1);
a[k<<1|1].w += a[k].flag *(a[k<<1|1].r-a[k<<1|1].l+1);
a[k].flag = 0;
}
int res;
void update(int k,int x,int y,int z) { //區間更新
if(a[k].l>=x &&a[k].r<=y) {
a[k].w += z*(a[k].r-a[k].l+1);
a[k].flag += z;
return;
}
if(a[k].flag)
down(k);
if(x <= a[k].mid())
update(k<<1,x,y,z);
if(y > a[k].mid())
update(k<<1|1,x,y,z);
a[k].w = a[k<<1].w+a[k<<1|1].w;
}
int ans;
void query(int k,int x,int y) { //區間查詢
if(a[k].l>=x && a[k].r<=y) {
res += a[k].w;
return;
}
if(a[k].flag)
down(k);
if(x <= a[k].mid())
query(k<<1,x,y);
if(y > a[k].mid())
query(k<<1|1,x,y);
a[k].w = a[k<<1].w+a[k<<1|1].w;
}
int main()
{
IO;
//IN;
int t ;
string s;
int n,k;
int q;
int x,y,z;
cin >> n >> k;
build(1,1,n);
while(k--){
cin >> s;
cin >> x >> y;
if(s[0] == 'Q'){
res = 0;
query(1,x,y);
cout << res << endl;
}
else
{
cin >> z ;
update(1,x,y,z);
}
}
return 0;
}