樹狀陣列入門
簡介
在許多情況下,我們都要維護一個數組的字首和 \(s[i]=a[1]+a[2]+\cdots+a[i]\) ,但如果我們修改了其中一個 \(a[i]\) ,那麼 \(s[i],s[i+1],\cdots,s[n]\) 都會發生變化。所以每次修改後我們都要對字首和陣列進行維護,在最壞的情況下複雜度為 \(O(n)\) 。而樹狀陣列能很好的解決這一問題,它的修改與求和操作複雜度都是 \(O(\log n)\)
原理
我們設原陣列為 \(a\) ,樹狀陣列為 \(c\) ,對於一個長度為 \(8\) 的原陣列,下面這張圖(from oi-wiki)展示了樹狀陣列的原理:
- \(c[1]=a[1]\)
- \(c[2]=a[1]+a[2]\)
- \(c[3]=a[3]\)
- \(c[4]=a[1]+a[2]+a[3]+a[4]\)
- \(c[5]=a[5]\)
- \(c[6]=a[5]+a[6]\)
- \(c[7]=a[7]\)
- \(c[8]=a[1]+a[2]+\cdots +a[8]\)
我們用 \(f(i)\) 表示 \(c[i]\) 儲存了多少個連續的陣列 \(a\) 中的元素,即:
\[c[i]=a[i]+a[i-1]+\cdots+a[i-f(i)+1] \]為了求出 \(f(i)\) ,我們列出一張表並觀察規律:
\(i\) 的十進位制 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | 8 |
---|---|---|---|---|---|---|---|---|
\(i\) |
0001 | 0010 | 0011 | 0100 | 0101 | 0110 | 0111 | 1000 |
\(f(i)\) | 0001 | 0010 | 0001 | 0100 | 0001 | 0010 | 0001 | 1000 |
通過觀察,可以得出 \(f(i)\) 的值就是 \(i\) 的二進位制表示中從最低位到高位出現的第一個 \(1\) 和它之前所有的 \(0\) 組成的二進位制數的值,換一種說法,如果 \(i\) 的二進位制表示中從最低位到高位有 \(k\) 個連續的 \(0\) ,第 \(k+1\) 位是 \(1\),那麼 \(f(i)=2^k\) ,在樹狀陣列中, \(f\) 有一個專門的名稱 —— \(lowbit\)
那麼如何求出 \(lowbit(i)\) 呢?設 \(i\) 的二進位制表示從最低位到高位有 \(k\)
int lowbit(int x)
{
return x & (-x);
}
接下來,我們考慮,如果 \(a[i]\) 的值改變了,如何更新 \(c\) ,我們可以發現, 當 \(i\) 的二進位制表示從最低位到高位有 \(k\) 個連續的 \(0\) ,第 \(k+1\) 位是 \(1\) 時, \(i+lowbit(i)\) 會使第 \(k+1\) 位變成 \(0\) 且前 \(k\) 位依然為 \(0\) ,所以 \(lowbit(i+lowbit(i))\) 一定大於 \(lowbit(i)\) ,那麼 \(c[i+lowbit(i)]\) 一定包含了 \(a[i]\) 的值,需要更新
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
最後考慮如何求和,根據 \(c\) 的定義可知, \(sum[i]=c[i]+c[i-lowbit[i]]+\cdots+c[1]\) ,求和的方法就很顯然了
long long get_sum(int x)
{
long long res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
一個最基本的樹狀陣列模板就實現了
變式
上述樹狀陣列模板是最基本的字首和模板,支援單點更新(更新原陣列某個元素的值),區間查詢(查詢原陣列某個區間的和
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int n, q;
ll c[MAX_N];
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
ll get_sum(int x)
{
ll res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
int a;
scanf("%d", &a);
update(i, a);
}
for(int i = 1; i <= q; i++) {
int opt, x, y;
scanf("%d%d%d", &opt, &x, &y);
if(opt == 1)
update(x, y);
else
printf("%lld\n", get_sum(y) - get_sum(x - 1));
}
return 0;
}
在此基礎上稍加變通,我們就可以讓樹狀陣列實現區間更新(對原陣列某個區間的所有元素都加上一個值)\單點查詢(查詢原陣列某元素的值)和區間更新\區間查詢
區間更新\單點查詢
考慮如何將求字首和的問題轉化為求某個元素的值的問題,可以發現,差分陣列的字首和就是某個元素的值,且在進行區間更新時,只需要修改區間兩端的差分陣列的值,所以我們用樹狀陣列維護原陣列的差分陣列就能實現區間更新\單點查詢的功能了
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll c[MAX_N];
int n, q;
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
while(x <= n) {
c[x] += k;
x += lowbit(x);
}
}
ll get_sum(int x)
{
ll res = 0;
while(x >= 1) {
res += c[x];
x -= lowbit(x);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
update(i, a[i] - a[i - 1]);
}
for(int i = 1; i <= q; i++) {
int opt, x, y, z;
scanf("%d", &opt);
if(opt == 1) {
scanf("%d%d%d", &x, &y, &z);
update(x, z);
update(y + 1, -z);
}else {
scanf("%d", &x);
printf("%lld\n", get_sum(x));
}
}
return 0;
}
區間更新\區間查詢
我們仍然採用差分的思路:
\(sum[i]=a[i]+a[i-1]+\cdots+a[1]=(d[1]+d[2]+\cdots+d[i])+(d[1]+\cdots+d[i-1])+\cdots+d[1]\)
進一步分解:
\(sum[i]=i(d[1]+d[2]+\cdots+d[i])-(1-1)d[1]-(2-1)d[2]-\cdots-(i-1)d[i]\)
所以只要用兩個樹狀陣列,一個維護 \(d[i]\) ,另一個維護 \((i-1)d[i]\) 即可
#include<bits/stdc++.h>
#define ll long long
using namespace std;
const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll sum1[MAX_N], sum2[MAX_N];
int n, q;
int lowbit(int x)
{
return x & (-x);
}
void update(int x, int k)
{
int t = x;
while(t <= n) {
sum1[t] += k;
sum2[t] += 1ll * (x - 1) * k;
t += lowbit(t);
}
}
ll get_sum(int x)
{
ll res = 0;
int t = x;
while(t >= 1) {
res = res + x * sum1[t] - sum2[t];
t -= lowbit(t);
}
return res;
}
int main()
{
scanf("%d%d", &n, &q);
for(int i = 1; i <= n; i++) {
scanf("%d", &a[i]);
update(i, a[i] - a[i - 1]);
}
for(int i = 1; i <= q; i++) {
int opt, l, r, x;
scanf("%d", &opt);
if(opt == 1) {
scanf("%d%d%d", &l, &r, &x);
update(l, x);
update(r + 1, -x);
}else {
scanf("%d%d", &l, &r);
printf("%lld\n", get_sum(r) - get_sum(l - 1));
}
}
return 0;
}