1. 程式人生 > 其它 >樹狀陣列入門

樹狀陣列入門

簡介

在許多情況下,我們都要維護一個數組的字首和 \(s[i]=a[1]+a[2]+\cdots+a[i]\) ,但如果我們修改了其中一個 \(a[i]\) ,那麼 \(s[i],s[i+1],\cdots,s[n]\) 都會發生變化。所以每次修改後我們都要對字首和陣列進行維護,在最壞的情況下複雜度為 \(O(n)\) 。而樹狀陣列能很好的解決這一問題,它的修改與求和操作複雜度都是 \(O(\log n)\)

原理

我們設原陣列為 \(a\) ,樹狀陣列為 \(c\) ,對於一個長度為 \(8\) 的原陣列,下面這張圖(from oi-wiki)展示了樹狀陣列的原理:

  • \(c[1]=a[1]\)
  • \(c[2]=a[1]+a[2]\)
  • \(c[3]=a[3]\)
  • \(c[4]=a[1]+a[2]+a[3]+a[4]\)
  • \(c[5]=a[5]\)
  • \(c[6]=a[5]+a[6]\)
  • \(c[7]=a[7]\)
  • \(c[8]=a[1]+a[2]+\cdots +a[8]\)

我們用 \(f(i)\) 表示 \(c[i]\) 儲存了多少個連續的陣列 \(a\) 中的元素,即:

\[c[i]=a[i]+a[i-1]+\cdots+a[i-f(i)+1] \]

為了求出 \(f(i)\) ,我們列出一張表並觀察規律:

\(i\) 的十進位制 1 2 3 4 5 6 7 8
\(i\)
的二進位制
0001 0010 0011 0100 0101 0110 0111 1000
\(f(i)\) 0001 0010 0001 0100 0001 0010 0001 1000

通過觀察,可以得出 \(f(i)\) 的值就是 \(i\) 的二進位制表示中從最低位到高位出現的第一個 \(1\) 和它之前所有的 \(0\) 組成的二進位制數的值,換一種說法,如果 \(i\) 的二進位制表示中從最低位到高位有 \(k\) 個連續的 \(0\) ,第 \(k+1\) 位是 \(1\),那麼 \(f(i)=2^k\) ,在樹狀陣列中, \(f\) 有一個專門的名稱 —— \(lowbit\)

那麼如何求出 \(lowbit(i)\) 呢?設 \(i\) 的二進位制表示從最低位到高位有 \(k\)

個連續的 \(0\) ,第 \(k+1\) 位是 \(1\) 。對 \(i\) 取反,這些 \(0\) 就都變成了 \(1\) ,第 \(k+1\) 位就變成了 \(0\) ,再加 \(1\) ,可以發現前 \(k+1\) 位還原到了最初的值,而更高位的二進位制位則是之前的反碼,此時對 \(i\) 和取反後加 \(1\)\(i\) 進行與運算,就能得到答案了。 根據補碼的原理, \(\sim i+1\) 等價於 \(-i\) ,所以 \(lowbit(i)=i\&(-i)\)

int lowbit(int x)
{
    return x & (-x);
}

接下來,我們考慮,如果 \(a[i]\) 的值改變了,如何更新 \(c\) ,我們可以發現, 當 \(i\) 的二進位制表示從最低位到高位有 \(k\) 個連續的 \(0\) ,第 \(k+1\) 位是 \(1\) 時, \(i+lowbit(i)\) 會使第 \(k+1\) 位變成 \(0\) 且前 \(k\) 位依然為 \(0\) ,所以 \(lowbit(i+lowbit(i))\) 一定大於 \(lowbit(i)\) ,那麼 \(c[i+lowbit(i)]\) 一定包含了 \(a[i]\) 的值,需要更新

void update(int x, int k)
{
    while(x <= n) {	
        c[x] += k;
        x += lowbit(x);
    }
}

最後考慮如何求和,根據 \(c\) 的定義可知, \(sum[i]=c[i]+c[i-lowbit[i]]+\cdots+c[1]\) ,求和的方法就很顯然了

long long get_sum(int x)
{
    long long res = 0;
    while(x >= 1) {
        res += c[x];
	x -= lowbit(x);
    }
    return res;
}

一個最基本的樹狀陣列模板就實現了

變式

上述樹狀陣列模板是最基本的字首和模板,支援單點更新(更新原陣列某個元素的值),區間查詢(查詢原陣列某個區間的和

LOJ 樹狀陣列模板1

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int n, q;
ll c[MAX_N];

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    while(x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    while(x >= 1) {
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        int a;
        scanf("%d", &a);
        update(i, a);
    }
    for(int i = 1; i <= q; i++) {
        int opt, x, y;
        scanf("%d%d%d", &opt, &x, &y);
        if(opt == 1)
            update(x, y);
        else
            printf("%lld\n", get_sum(y) - get_sum(x - 1));
    }
    return 0;
}

在此基礎上稍加變通,我們就可以讓樹狀陣列實現區間更新(對原陣列某個區間的所有元素都加上一個值)\單點查詢(查詢原陣列某元素的值)和區間更新\區間查詢

區間更新\單點查詢

考慮如何將求字首和的問題轉化為求某個元素的值的問題,可以發現,差分陣列的字首和就是某個元素的值,且在進行區間更新時,只需要修改區間兩端的差分陣列的值,所以我們用樹狀陣列維護原陣列的差分陣列就能實現區間更新\單點查詢的功能了

LOJ 樹狀陣列模板2

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll c[MAX_N];
int n, q;

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    while(x <= n) {
        c[x] += k;
        x += lowbit(x);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    while(x >= 1) {
        res += c[x];
        x -= lowbit(x);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        update(i, a[i] - a[i - 1]);
    }
    for(int i = 1; i <= q; i++) {
        int opt, x, y, z;
        scanf("%d", &opt);
        if(opt == 1) {
            scanf("%d%d%d", &x, &y, &z);
            update(x, z);
            update(y + 1, -z);
        }else {
            scanf("%d", &x);
            printf("%lld\n", get_sum(x));
        }
    }
    return 0;
}
區間更新\區間查詢

我們仍然採用差分的思路:

\(sum[i]=a[i]+a[i-1]+\cdots+a[1]=(d[1]+d[2]+\cdots+d[i])+(d[1]+\cdots+d[i-1])+\cdots+d[1]\)

進一步分解:

\(sum[i]=i(d[1]+d[2]+\cdots+d[i])-(1-1)d[1]-(2-1)d[2]-\cdots-(i-1)d[i]\)

所以只要用兩個樹狀陣列,一個維護 \(d[i]\) ,另一個維護 \((i-1)d[i]\) 即可

LOJ 樹狀陣列模板3

#include<bits/stdc++.h>
#define ll long long
using namespace std;

const int MAX_N = 1000000 + 5;
int a[MAX_N];
ll sum1[MAX_N], sum2[MAX_N];
int n, q;

int lowbit(int x)
{
    return x & (-x);
}

void update(int x, int k)
{
    int t = x;
    while(t <= n) {
        sum1[t] += k;
        sum2[t] += 1ll * (x - 1) * k;
        t += lowbit(t);
    }
}

ll get_sum(int x)
{
    ll res = 0;
    int t = x;
    while(t >= 1) {
        res = res + x * sum1[t] - sum2[t];
        t -= lowbit(t);
    }
    return res;
}

int main()
{
    scanf("%d%d", &n, &q);
    for(int i = 1; i <= n; i++) {
        scanf("%d", &a[i]);
        update(i, a[i] - a[i - 1]);
    }
    for(int i = 1; i <= q; i++) {
        int opt, l, r, x;
        scanf("%d", &opt);
        if(opt == 1) {
            scanf("%d%d%d", &l, &r, &x);
            update(l, x);
            update(r + 1, -x);
        }else {
            scanf("%d%d", &l, &r);
            printf("%lld\n", get_sum(r) - get_sum(l - 1));
        }
    }
    return 0;
}