1. 程式人生 > 實用技巧 >「學習筆記」樹狀陣列

「學習筆記」樹狀陣列

概念

樹狀陣列 (\(Binary\) \(Indexed\) \(Tree\)) 是一個區間查詢和單點修改複雜度都為 \(log(n)\) 的資料結構。主要用於維護序列的字首和。

長。。。這樣

其中 \(c\) 陣列為樹狀陣列。
容易發現:
\(c[1]=a[1]\)
\(c[2]=a[1]+a[2]\)
\(c[3]=a[3]\)
\(……\)

\(c[i]\) 到底是多少個 \(a[j]\) 相加呢?答案是 \(lowbit(i)\) 個。

lowbit

在這之前,我們需要了解一下樹狀數組裡的關鍵操作—— \(lowbit\)

\(lowbit(x)\) : 將 \(x\)

轉化為二進位制後,取出其最低位的 \(1\) (結果即為只保留最低位的 \(1\) 及其後面的 \(0\),並將其轉化為十進位制後的值)。

寫法

關於 \(lowbit\) 的寫法有很多種,這裡給出兩種方法:

  1. \(lowbit(x)=x-(x\) & \((x-1))\)
    解釋一下:設 \(x\) = \((A1B)_2\) ( \(A\)\(x\) 最低位的 \(1\) 之前的部分,\(B\) 為之後的部分,全部為 \(0\) ),\(x-1=(A0C)_2\)\(B,C\) 長度一致,\(C\) 全為 \(1\)),則 \(x\) & \(x-1=(11…11000…0)_2\)
    \(len_A\)\(1\)\(len_{B/C}+1\)\(0\)),所以 \(x\) 再減去這部分,就得到了 \((00…00100…00)\)\(1\) 前面 \(len_A\)\(0\)\(1\) 後面 \(len_{B/C}+1\)\(0\)),即取出了最低位的 \(1\)
  2. \(lowbit(x)=x\) & \(-x\)
    這也是最常用的寫法。
    \(-x=\)~\(x+1\),即先將 \(x\) 在二進位制下取反( \(0 \rightarrow 1\)\(1 \rightarrow 0\)),再加上 \(1\) ,最低位的 \(1\) 在先前的取反後變成 \(0\)
    ,其右邊的 \(0\) ,全部變為 \(1\) ,所以加的 \(1\) ,讓其右邊的 \(1\) 全部變回了 \(0\) ,它本身加上了進位的 \(1\) ,也變回了 \(1\)\(x\) 再與上這一部分,只有最低位的 \(1\) 的位置,在 \(x,-x\) 上都為 \(1\) ,所以也能取出了最低位的 \(1\)

作用

  1. 對原陣列(設為 \(a\))進行更新(\(update\))操作,同樣在初始化,建樹狀陣列的時候,每次輸入 \(a[i]\) ,可以通過更新達到初始化的效果。
void Update(int x, int y) {
	for (int i = x; i <= n; i += lowbit(i)) BIT[i] += y;
}

for (int i = 1; i <= n; i++) {
	scanf("%d", &a[i]);
	Update(i, a[i]);
} 
  1. 求得字首和
int Sum(int x) {
	int ans = 0;
	for (int i = x; i; i -= lowbit(i)) ans += BIT[i];
	return ans;
}

具體操作

一維

單點修改+區間查詢

Link

這裡兩個操作,可以通過更新和查詢字首和實現。

  • 給定 \(i,x\),將 \(a[i]\) 加上 \(x\)
    \(update(i,x)\)
  • 給定 \(l,r\) ,求 \(\sum_{i=l}^ra[i]\) 的值(換言之,求 \(a[l]+a[l+1]+\dots+a[r]\) 的值)。
    意思是求 \([l,r]\) 的區間和,可以通過 \(sum[r]-sum[l-1]\) 求得。
#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1e6 + 5;
int n, m;
int a[MAXN];
ll BIT[MAXN]; 

int lowbit(int x) { return x & (-x); }

void Update(int x, int y) {
	for (int i = x; i <= n; i += lowbit(i)) BIT[i] += y;
}

ll Sum(int x) {
	ll ans = 0;
	for (int i = x; i; i -= lowbit(i)) ans += BIT[i];
	return ans;
}

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
		Update(i, a[i]);
	} 
	for (int i = 1; i <= m; i++) {
		int type, l, r;
		scanf("%d %d %d", &type, &l, &r);
		if (type == 1) {
			Update(l, r);
		}
		else {
			printf("%lld\n", Sum(r) - Sum(l - 1));
		}
	}
	return 0;
}

區間修改+單點查詢

Link
這裡需要用到差分,一個數組的差分陣列和它的字首和陣列是互逆的。

a[6] = {0, 1, 2, 3, 4, 5};
cf[6] = {0, 1, 1, 1, 1, 1}; //cf[i] = a[i] - a[i - 1] 差分
qzh[6] = {0, 1, 3, 6, 10, 15}; //qzh[i] = qzh[i - 1] + a[i] 字首和

cf_qzh[6] = {0, 1, 2, 3, 4, 5}; //差分陣列的字首和陣列
qzh_cf[6] = {0, 1, 2, 3, 4, 5}; //字首和陣列的差分陣列

總的來說就是:差分陣列的字首和陣列就是原陣列,字首和陣列的差分陣列也是原陣列。
我們要用到的就是,差分陣列的字首和陣列就是原陣列這一特性。

而差分陣列的區間修改是將 \(cf[l]+k,cf[r+1]-k\) (設讓 \([l,r]\) 裡的每個數加上 \(k\)\(cf\) 為原陣列的差分陣列)

對於這道題,我們不再在原陣列上建樹狀陣列了,改在差分陣列上建樹狀陣列。

每次區間修改,就對 \(cf[l]+k,cf[r+1]-k\) ,查詢每個數,即求 \([1,x]\) 的字首和。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1e6 + 5;
int n, m;
int a[MAXN];
ll BIT[MAXN];

int lowbit(int x) { return x & (-x); }

void Update(int x, int y) {
	for (int i = x; i <= n; i += lowbit(i)) BIT[i] += y;
}

ll Sum(int x) {
	ll ans = 0;
	for (int i = x; i; i -= lowbit(i)) ans += BIT[i];
	return ans;
}

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
		Update(i, a[i] - a[i - 1]);
	}
	for (int i = 1; i <= m; i++) {
		int type, l, r, x;
		scanf("%d", &type);
		if (type == 1) {
			scanf("%d %d %d", &l, &r, &x);
			Update(l, x);
			Update(r + 1, -x); 
		}
		else {
			scanf("%d", &x);
			printf("%lld\n", Sum(x));
		} 
	}
	return 0;
}

區間修改+區間查詢

Link

區間修改依然像上一問那麼做,樹狀陣列依然建在差分陣列上。

區間查詢則需要稍微推一下:(設求 \([1,p]\) 的字首和,\(a\) 為差分陣列)

\(\sum_{i=1}^p\sum_{j=1}^ia[j]\) 展開一下

\(a[1]+(a[1]+a[2])+(a[1]+a[2]+a[3])+\dots+(a[1]+a[2]+a[3]+\dots+a[p])\)

\(p*a[1]+(p-1)*a[2]+(p-2)*a[3]+\dots+2*a[p-1]+a[p]\)

因為樹狀陣列最主要的作用就是求字首和,所以要把式子落到字首和上。

\(p*(a[1]+a[2]+\dots+a[p])-0*a[1]-1*a[2]-\dots-(p-1)*a[p]\)

\(\sum_{i=1}^pa[i]*p-\sum_{j=1}^pa[j]*(j-1)\)

所以只需要用兩個樹狀陣列,一個維護 \(a[i]\) 的字首和,一個維護 \(a[i]*(i-1)\) 的字首和。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1e6 + 5;
int n, m;
int a[MAXN];
ll BIT1[MAXN], BIT2[MAXN];

int lowbit(int x) {
	return x & (-x);
}

void Update(int x, int y) {
	for (int i = x; i <= n; i += lowbit(i))
		BIT1[i] += y, BIT2[i] += (ll)(x - 1) * y;
}

ll Sum(int x) {
	ll ans = 0;
	for (int i = x; i; i -= lowbit(i))
		ans += BIT1[i] * x - BIT2[i];
	return ans;
}

int main() {
	scanf("%d %d", &n, &m);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
		Update(i, a[i] - a[i - 1]);
	}
	for (int i = 1; i <= m; i++) {
		int type, l, r, x;
		scanf("%d %d %d", &type, &l, &r);
		if (type == 1) {
			scanf("%d", &x);
			Update(l, x);
			Update(r + 1, -x);
		}
		else {
			printf("%lld\n", Sum(r) - Sum(l - 1));
		}
	}
	return 0;
}

求逆序對

樹狀陣列還可以求得逆序對,很巧妙。

首先你需要將輸入進來的序列 \(a\) 離散化(當數值範圍較小可以免去離散化)

離散化是什麼呢?當資料只與它們之間的相對大小有關,而與具體是多少無關時,可以進行離散化。

\(e.g.\)

原陣列: 	10000  1  8  99  500
離散化陣列:   5  1  2  3  4

實現方法主要有兩種,這裡只介紹用 \(STL\) 實現的

#include <algorithm>
using namespace std;

int query(int x) {
	return lower_bound(a + 1, a + 1 + m, x) - a; //二分查詢 log(m)
}

for (int i = 1; i <= n; i++) {
	scanf("%d", &a[i]);
	b[i] = a[i];
}
sort(a + 1, a + 1 + n);
m = unique(a + 1, a + 1 + n) - (a + 1); //去重
for (int i = 1; i <= n; i++) {
	b[i] = query(b[i]);
}

言歸正傳,在離散化之後,從 \(1\)~\(n\) 迴圈,\(i\) 可以表示進了樹狀陣列的數的個數,每次將 \(update(b[i],1)\) 表示又進了一個第 \(b[i]\) 小的數,這時候的 \(Sum(b[i])\) ,表示的就是進入了樹狀陣列中的數中 \(\leq\) 當前數的數的總數, 則 \(i-Sum(b[i])\) 就表示 \(>\) 當前數且在當前數之前的數,因為此時在樹狀陣列中的數都是在當前數之前的,還有它自己,不 \(\leq\) 它那就是 \(>\) 它,所以 \(i-Sum(b[i])\) 即為當前數的逆序對數,累加即可。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1005;
int n, m, ans;
int a[MAXN], b[MAXN], BIT[MAXN];

int lowbit(int x) {
	return x & (-x);
}

void Update(int x, int y) {
	for (int i = x; i <= n; i += lowbit(i)) BIT[i] += y;
}

int Sum(int x) {
	int ans = 0;
	for (int i = x; i; i -= lowbit(i)) ans += BIT[i];
	return ans;
}

int query(int x) {
	return lower_bound(a + 1, a + 1 + m, x) - a;
}

int main() {
	scanf("%d", &n);
	for (int i = 1; i <= n; i++) {
		scanf("%d", &a[i]);
		b[i] = a[i];
	}
	sort(a + 1, a + 1 + n);
	m = unique(a + 1, a + 1 + n) - (a + 1);
	for (int i = 1; i <= n; i++) {
		b[i] = query(b[i]);
	}
	for (int i = 1; i <= n; i++) {
		Update(b[i], 1);
		ans += i - Sum(b[i]);
	}
	printf("%d", ans);
	return 0;
}

二維

單點修改+區間查詢

Link

這是二維樹狀數組裡最基本的操作,其實和一維的大同小異,定義是介樣的(\(baidu\)

\(BIT[x][y]=a[i][j](x-lowbit(x) + 1 \leq i \leq x,y-lowbit(y) + 1 \leq j \leq y)\)

有了定義,實現就很簡單了,跟一維沒什麼兩樣。

void Update(int x, int y, int z) { //單點修改
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j)) BIT[i][j] += z;
}

int Sum(int x, int y) { //查詢左上角(1,1)右下角(x,y)的矩陣的元素之和
    int ans = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j)) ans += BIT[i][j];
    return ans;
}

只不過在求和的時候,像二維字首和一樣,需要用到容斥,結合圖來說:


我們要求的是紅色矩陣中元素的和,設紅色矩陣左上角座標 \((a, b)\),右下角 \((c,d)\)

所以紅色矩陣可以看作,左上角為 \((1,1)\) 右下角為 \((c,d)\) 的矩陣減去左上角為 \((1,1)\) 右下角為 \((c,b)\) 的矩陣和左上角為 \((1,1)\) 右下角為 \((a,d)\) 的矩陣(即減去兩個藍色矩陣),這樣就只剩下了紅色矩陣,但由於綠色矩陣被減了兩次,再加上即可。

\(\sum_{i=a}^c \sum_{j=b}^da[i]=Sum(c, d) - Sum(a - 1, d) - Sum(c, b - 1) + Sum(a - 1, b - 1)\)

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1 << 12;
int n, m, type;
ll BIT[MAXN + 5][MAXN + 5];

int lowbit(int x) { return x & (-x); }

void Update(int x, int y, int z) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j)) BIT[i][j] += z;
}

ll Sum(int x, int y) {
    ll ans = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j)) ans += BIT[i][j];
    return ans;
}

int main() {
    scanf("%d %d", &n, &m);
    while (~scanf("%d", &type)) {
        int a, b, c, d;
        if (type == 1) {
            scanf("%d %d %d", &a, &b, &c);
            Update(a, b, c);
        } else {
            scanf("%d %d %d %d", &a, &b, &c, &d);
            printf("%lld\n", Sum(c, d) - Sum(a - 1, d) - Sum(c, b - 1) + Sum(a - 1, b - 1));
        }
    }
    return 0;
}

區間修改+單點查詢

Link

二維採用的方法依然是差分,但由於初始時是一個零矩陣(所有元素均為 \(0\)),所以只用進行相應的修改。舉個例子

0  0  0  0  0
0  0  0  0  0
0  0  0  0  0
0  0  0  0  0
0  0  0  0  0 

想要到達這樣的效果:將左上角為 \((a,b)\) 右下角為 \((c,d)\) 的矩陣內的元素所有加上 \(x\),即:

0  0  0  0  0
0  x  x  x  0
0  x  x  x  0
0  x  x  x  0
0  0  0  0  0

我們可以將 \((a,b)\) 加上 \(x\)\((c+1,d+1)\) 加上 \(x\)\((a,d+1)\)\((c+1,b)\) 減去 \(x\)

0  0  0  0  0
0 +x  0  0 -x
0  0  0  0  0
0  0  0  0  0
0 -x  0  0 +x

由於在差分上建樹狀陣列,所以單點查詢還是求其字首和。

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1 << 12;
int n, m, type;
ll BIT[MAXN + 5][MAXN + 5];

int lowbit(int x) { return x & (-x); }

void Update(int x, int y, int z) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j)) {
            BIT[i][j] += z;
        }
}

ll Sum(int x, int y) {
    ll ans = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j)) {
            ans += BIT[i][j];
        }
    return ans;
}

int main() {
    scanf("%d %d", &n, &m);
    while (~scanf("%d", &type)) {
        if (type == 1) {
            int a, b, c, d, x;
            scanf("%d %d %d %d %d", &a, &b, &c, &d, &x);
            Update(a, b, x);
            Update(c + 1, d + 1, x);
            Update(a, d + 1, -x);
            Update(c + 1, b, -x);
        } else {
            int a, b;
            scanf("%d %d", &a, &b);
            printf("%lld\n", Sum(a, b));
        }
    }
    return 0;
}

區間修改+區間查詢

Link

首先還是在二維查分陣列上建樹狀陣列(設求左上角 \((1,1)\) ,右下角 \((a,b)\) 的矩陣的元素和,\(a\) 為差分陣列)

\(Sum=\sum_{i=1}^x\sum_{j=1}^y\sum_{x=1}^i\sum_{y=1}^ja[x][y]\)

類比一維時的做法,我們看看每個 \(a[x][y]\) 出現了多少次:

\(a[1][1]\) 出現了 \(x * y\) 次;
\(a[1][2]\) 出現了 \(x * (y-1)\) 次;
\(a[2][1]\) 出現了 \((x-1) * y\) 次;
\(a[2][2]\) 出現了 \((x - 1) * (y - 1)\) 次;
\(\dots\)

找找規律發現 \(a[x][y]\) 出現 \((x-i+1) * (y-j+1)\)

即:

\(=\sum_{i=1}^x\sum_{j=1}^ya[i][j] * (x-i+1) * (y-j+1)\)

\(=\sum_{i=1}^x\sum_{j=1}^ya[i][j] * [(x+1)*(y+1)-(x+1)*j-(y+1)*i+i*j]\)

\(=\sum_{i=1}^x\sum_{j=1}^ya[i][j]*(x+1)*(y+1)-\sum_{i=1}^x\sum_{j=1}^ya[i][j]*(x+1)*j-\sum_{i=1}^x\sum_{j=1}^ya[i][j]*(y+1)*i+\sum_{i=1}^x\sum_{j=1}^ya[i][j] * i*j\)

由這個式子不難看出我們需要四個樹狀陣列分別維護

\(a[i][j]\)\(a[i][j]*j\)\(a[i][j]*i\)\(a[i][j] * i*j\)

的字首和。。。加個快讀更快

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 1 << 11;
int n, m, type;
ll BIT1[MAXN + 5][MAXN + 5], BIT2[MAXN + 5][MAXN + 5], BIT3[MAXN + 5][MAXN + 5], BIT4[MAXN + 5][MAXN + 5];

int lowbit(int x) { return x & (-x); }

void Update(int x, int y, int z) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j)) {
        	BIT1[i][j] += z;
        	BIT2[i][j] += (ll)x * z;
        	BIT3[i][j] += (ll)y * z;
        	BIT4[i][j] += (ll)x * y * z;
		}
}

ll Sum(int x, int y) {
    ll ans = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j)) {
			ans += BIT1[i][j] * (x + 1) * (y + 1) - BIT2[i][j] * (y + 1) - BIT3[i][j] * (x + 1) + BIT4[i][j];
		}
    return ans;
}

void read(int &x) {
	x = 0; 
	int f = 1;
	char s = getchar(); 
	while (s > '9' || s < '0') { 
		if (s == '-') f = -1; 
		s = getchar(); 
	}
	while (s >= '0' && s <= '9') { 
		x = (x << 3) + (x << 1) + (s - '0');
		s = getchar(); 
	}
	x *= f;
}

int main() {
    read(n); read(m);
    while (~scanf("%d", &type)) {
        if (type == 1) {
        	int a, b, c, d, x;
        	read(a); read(b); read(c); read(d); read(x);
            Update(a, b, x);
            Update(c + 1, d + 1, x);
            Update(a, d + 1, -x);
            Update(c + 1, b, -x);
        } 
		else {
			int a, b, c, d;
			read(a); read(b); read(c); read(d);
            printf("%lld\n", Sum(c, d) - Sum(a - 1, d) - Sum(c, b - 1) + Sum(a - 1, b - 1));
        }
    }
    return 0;
}

習題

Problem 數星星

讀題時還想這道題挺複雜,直到我看到了輸入格式。真香

把輸入的點放在座標系上,相當於每次從下往上,從左往右這樣輸入。

我們分別看看輸入的點的橫縱座標, \(y\) 座標沒什麼價值,因為在它之前輸入的一定 \(\leq\) 當前 \(y\) 座標(顯然的啊),\(x\) 座標則很有價值,因為它的等級是看星星的左下方(包含正左和正下)的星星總和,當前輸入的星星的等級需要考慮的星星已經囊括在了已經輸入的星星之中,它的等級即為已輸入的星星中 \(x\) 座標為 \(\leq\) 當前 \(x\) 座標的星星個數。

5  o  o  o  o
4  o  *  !  o
3  o  *  *  o
2  *  *  o  *
1  2  3  4  5

舉個例子,設當前輸入到了 \(!\) 這個位置的星星,它的等級即為左下角為 \((1,1)\),右上角為 \((4,4)\) 這個矩陣中的星星個數(不包括它自己),這些星星已在輸入當前星星前被輸入了。

所以我們可以在 \(x\) 軸上建樹狀陣列統計每個 \(x\) 座標上的星星個數,每次輸入直接可以得到該星星的等級 \(Sum(x)\) ,並更新當前 \(x\) 座標上的星星數量(加上一)。

注意,這裡輸入的座標有\(0\),而樹狀陣列下標不能為\(0\),所以要加\(1\)處理!

#include <cstdio>
#include <cstring>
#include <algorithm>
#define ll long long
using namespace std;

const int MAXN = 32005;
const int MAXM = 15005;
int n;
int BIT[MAXN], lev[MAXM];

struct node {
    int x;
    int y;
} a[MAXN];

int lowbit(int x) { return x & (-x); }

void update(int x, int y) {
    for (; x <= 32001; x += lowbit(x)) BIT[x] += y; //因為加 1了,所以要遍歷到32001
}

int query(int x) {
    int ans = 0;
    for (; x; x -= lowbit(x)) ans += BIT[x];
    return ans;
}

int main() {
    scanf("%d", &n);
    for (int i = 1; i <= n; i++) {
        scanf("%d %d", &a[i].x, &a[i].y);
        int sum = query(a[i].x + 1);
        update(a[i].x + 1, 1); //+1
        lev[sum]++;
    }
    for (int i = 0; i < n; i++) {
        printf("%d\n", lev[i]);
    }
    return 0;
}