【學習筆記】樹狀陣列
原理
原理最近暫時沒有時間寫。等我後面來補
引例1 給定一個長度為n序列a,有m次操作,操作分為兩種,一是給出一個區間,求區間之和,二是給一個數加上一個值。
如果我們直接在陣列a上做這個問題,區間和累加最多是O(n),而單點修改則是O(1);
如果我們考慮字首和優化,那麼區間和是O(1)的,而單點修改最壞則是O(n);
總的複雜度最壞都是O (mn),如果n和m都是10的5次方級別 顯然會超時
是否存在更優秀的解法呢?
有!!!樹狀陣列可以做到mlogn!!
假設 N = 2 ^ ik + 2 ^ ik - 1 + …… + 2 ^ i1;
其中 ik > ik - 1 > ik - 2 > ……> i1;
我們考慮把(0,N】這個區間拆分成以下的區間
- (x - 2 ^ i1,x];
- (x - 2 ^ i2 - 2 ^ i1, x - 2^i1]
- 一直到最後一個區間
- (0,x - 2 ^ i1 - 2 ^ i2 - 2 ^ i3 - …… - 2 ^ ik - 1]
注意以上區間均為左開右閉
以上區間的長度恰好為log(x),即x的二進位制串長度
並且我們發現對於每個區間(l ,r】來說,區間的長度恰好為r的二進位制數的最後一位1所對應的次冪
我們繼續思考 如果我們要求一個區間【1,n】的總和,可不可以把這個大區間拆分成log(n)個小區間,先求出小區間之和,再累加到我們的大區間。
那麼如何知道大區間所需要的小區間有哪些,又如何求小區間之和呢
首先我們已經知道了每個以r為右端點的區間長度,所以我們不需要知道左端點(因為我們可以自己求出來)
那麼我不妨就以右端點為下標來表示區間
我們記 c[ r ] = [ r - lowbit(r)+ 1,r ];
lowbit是取一個二進位制數的最小的一,也就是r所對應2進位制數最後的一位1,不懂的可以藍書從基礎部分看起。可以O(1)求出
下面這張圖以【1,8】這個區間為例;(摘自OI wiki)
我們發現c【1】 區間長度為1
c[ 2 ] 長度為2
c【3】長度為 1
c[4] 長度為4
不難發現所有奇數為右端點的區間長度均為1(原因是奇數的最後一位1恰好就是十進位制下的1)
假設我們要求1 ~ 6的區間和
我們首先加上c【6】,然後我們發現還得加上c【4】
那c6和c4有什麼關係呢? 注意 6 - lowbit(6)= 4,這真是太妙了!
所以我們只需要讓一個區間右端點x 不斷減去 自身的lowbit 直到它等於 0為止即可算出 1 ~x的區間和
既然1 ~ x的和算出來了,我們思考之前字首和的思想
任意一個區間l ~ r也可以被算出來
而每次計算一個區間最多隻要累加 log(n)次 太妙了!
我們再來看單點加,
顯然只有包含當前節點的父節點的值會受到影響
而我們發現每個內部節點c【x】的父節點就是c【x + lowbit(x)】,不斷做運算直到x > n即可。
單點修改(加)(log n)
void add(int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; }
區間求和(log n)
LL sum(int x) { LL res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; }
引例2 把第一個問題的兩種操作改成給一個區間加上一個給定的值,或是查詢任意一個數的值
原先的問題是單點加和區間求和
而現在問題變成了區間加和單點查詢
其實很容易想到差分,單點查詢我們對差分陣列求和一遍就可以了(logn),而區間加也只需要給兩個點加上值即可(logn)
code:
#include<bits/stdc++.h> using namespace std; const int N = 100010; int tr[N]; int a[N]; int n,m; typedef long long LL; int lowbit(int x) { return x & -x; } void add(int x, int c) { for (int i = x; i <= n; i += lowbit(i)) tr[i] += c; } LL sum(int x) { LL res = 0; for (int i = x; i; i -= lowbit(i)) res += tr[i]; return res; } int main() { cin >> n >> m; for(int i = 1; i <= n; ++ i) scanf("%d",&a[i]); for(int i = 1; i <= n; ++ i) add(i,a[i] - a[i - 1]); while(m --) { string op; cin >> op; if(op == "C") { int l,r,d; scanf("%d%d%d",&l,&r,&d); add(l,d); add(r + 1,-d); } else { int x; scanf("%d",&x); printf("%lld\n",sum(x)); } } }
引例3 在前面兩個問題的資料範圍內,能否同時做到區間求和和區間加呢
1.對於區間加來說,我們同樣用到差分。
2.考慮區間和能否用到差分呢?我們會發現a1 + a2 + a3 + …… + ax
其實等於 b1 + b1 + b2 + b1 + b2 + b3 + ……+ bx;(可以在紙上畫出來)
我們不妨把它補成一個長為x + 1,寬為x的矩陣,其中每行均代表 b1 + b2 + b3 + …… + bx
此時我們發現答案等於 (x + 1)Σ(i從 1 到 n)bi 減去 Σ(i從1到 n)(bi * i);
由此我們只需要開兩個陣列,分別維護字首和即可。
程式碼:
#include <cstdio> #include <cstring> #include <algorithm> #include <iostream> using namespace std; const int N = 100010; typedef long long LL; LL tr1[N]; LL tr2[N]; int a[N]; int n,m; int lowbit(int x) { return x & -x; } void add(LL tr[],int x,LL c) { for(int i = x; i <= n; i += lowbit(i)) tr[i] += c; } LL sum(LL tr[],int x) { LL res = 0; for(int i = x;i; i -= lowbit(i)) res += tr[i]; return res; } LL prefix_sum(int x) { return (x + 1) * sum(tr1,x) - sum(tr2,x); } int main() { cin >> n >> m; for(int i = 1; i <= n; ++ i) scanf("%d",&a[i]); for(int i = 1; i <= n; ++ i) { int b = a[i] - a[i - 1]; add(tr1,i,b); add(tr2,i,1LL * b * i); } while(m --) { string op; int l,r,d; cin >> op; if(op == "C") { scanf("%d%d%d",&l,&r,&d); add(tr1,l,d); add(tr1,r + 1,-d); add(tr2,l,d * l); add(tr2,r + 1,-d * (r + 1)); } else { scanf("%d%d",&l,&r); printf("%lld\n",prefix_sum(r) - prefix_sum(l - 1)); } } return 0; }