1. 程式人生 > 其它 >【資料結構】二維樹狀陣列

【資料結構】二維樹狀陣列

一、二維樹狀陣列

二維樹狀陣列,其實就是一維的樹狀陣列上的節點再套個樹狀陣列,就變成了二維樹狀陣列了。

const int N = 1e3 + 10;
int tr[N][N], n, m;

#define lowbit(x) (x & -x)

void add(int x, int y, int d) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            tr[i][j] += d;
}
int query(int x, int y) {
    int ret = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j))
            ret += tr[i][j];
    return ret;
}

二、單點修改,區間查詢

LOJ #133. 二維樹狀陣列 1:單點修改,區間查詢

給出一個 \(n × m\) 的零矩陣 \(A\) ,你需要完成如下操作:

  • \(1\)\(x\)\(y\)\(k\) :表示元素 \(A\)_{\(x\) , \(y\)} 自增 \(k\)
  • \(2\)\(a\)\(b\)\(c\)\(d\): 表示詢問左上角為 \((a,b)\) ,右下角為 \((c,d)\) 的子矩陣內所有數的和

單點增加,因此可以直接加上就可以了

#include <bits/stdc++.h>
using namespace std;

typedef long long LL;
const int N = 5000; // 2^(12)=4096

int n, m;

LL tr[N][N];
#define lowbit(x) (x & -x)
void add(int x, int y, int d) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            tr[i][j] += d;
}
LL query(int x, int y) {
    LL ret = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j))
            ret += tr[i][j];
    return ret;
}

int main() {
    //加快讀入
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m;
    int opt;
    while (cin >> opt) {
        if (opt == 1) {
            int x, y, d;
            cin >> x >> y >> d;
            add(x, y, d);
        } else {
            int x1, y1, x2, y2;
            cin >> x1 >> y1 >> x2 >> y2;
            cout << query(x2, y2) - query(x1 - 1, y2) - query(x2, y1 - 1) + query(x1 - 1, y1 - 1) << '\n';
        }
    }
    return 0;
}

三、區間修改,單點查詢

LOJ #134. 二維樹狀陣列 2:區間修改,單點查詢

給出一個 \(n × m\) 的零矩陣 \(A\) ,你需要完成如下操作:

  • \(1 \, a \, b \, c \, d \, k\):表示左上角為 \((a,b)\) ,右下角為 \((c,d)\) 的子矩陣內所有數都自增加 \(k\)
  • \(2 \, x \, y\) :表示詢問元素 \(A_{x,y}\) 的值。

只需要利用一個二維樹狀陣列,維護一個二維差分陣列,單點查詢即可。

#include <bits/stdc++.h>

using namespace std;
typedef long long LL;
const int N = 5000;
int bit[N][N];
int n, m;

LL tr[N][N];
#define lowbit(x) (x & -x)
void add(int x, int y, int d) {
    for (int i = x; i <= n; i += lowbit(i))
        for (int j = y; j <= m; j += lowbit(j))
            tr[i][j] += d;
}
LL query(int x, int y) {
    LL ret = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j))
            ret += tr[i][j];
    return ret;
}

int main() {
    //加快讀入
    ios::sync_with_stdio(false), cin.tie(0);

    cin >> n >> m;
    int op;
    while (cin >> op) {
        if (op == 1) {
            int x1, y1, x2, y2, d;
            cin >> x1 >> y1 >> x2 >> y2 >> d;
            //二維差分
            add(x1, y1, d);
            add(x1, y2 + 1, -d);
            add(x2 + 1, y1, -d);
            add(x2 + 1, y2 + 1, d);
        } else {
            int x, y;
            cin >> x >> y;
            cout << query(x, y) << '\n';
        }
    }
    return 0;
}

四、區間修改,區間查詢

LOJ #135. 二維樹狀陣列 3:區間修改,區間查詢

給定一個大小為 \(N × M\) 的零矩陣,直到輸入檔案結束,你需要進行若干個操作,操作有兩類:

  • \(1 \, a\, b\, c\, d\, x\),表示將左上角為 \((a,b)\) ,右下角為 \((c,d)\) 的子矩陣全部加上 \(x\)

  • \(2\, a\, b\, c\, d\,\) , 表示詢問左上角為 \((a,b)\) ,右下角為 \((c,d)\) 為頂點的子矩陣的所有數字之和。

考慮字首和 \(\large S_{x,y}\) 和原陣列 \(a\) 和差分陣列 \(b\) 的關係。

\(\large \displaystyle S_{x,y}=\sum_{i=1}^x\sum_{j=1}^ya_{i,j} \\ \,\, \,\, \,\, \,\, \,\, \, =\sum_{i=1}^x\sum_{j=1}^y\sum_{k=1}^i\sum_{l=1}^jb_{k,l} \\ \,\, \,\, \,\, \,\, \,\, \, = \sum_{i=1}^x\sum_{j=1}^y[xy-y(i-1)-x(j-1)+(i-1)(j-1)]b_{i,j} \\ \,\, \,\, \,\, \,\, \,\, \, = xy\sum_{i=1}^x\sum_{j=1^y}b_{i,j}-y\sum_{i=1}^x\sum_{j=1}^y(i-1)b_{i,j}-x\sum_{i=1}^x\sum_{j=1}^y(j-1)b_{i,j}+\sum_{i=1}^x\sum_{j=1}^y(i-1)(j-1)b_{i,j} \)

為什麼可以推匯出這樣的公式?考慮單個位置 \((x,y)\)\(b_{i,j}\):
\([xy-y(i-1)-x(j-1)+(i-1)(j-1)]b_{i,j}\)(利用容斥原理),所以將每個位置加起來,就是\(s_{x,y}\)。因此,我們只需要維護四個樹狀陣列,分別維護\(b_{i,j},(i-1)b_{i,j},(j-1)b_{i,j},(i-1)(j-1)b_{i,j}\),就可以了。

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N = 2050;
int n, m;
LL a[N][N], b[N][N], c[N][N], d[N][N];

#define lowbit(x) (x & -x)

void add(int x, int y, int v) {
    for (int i = x; i <= n; i += lowbit(i)) {
        for (int j = y; j <= m; j += lowbit(j)) {
            a[i][j] += v;
            b[i][j] += (x - 1) * v;
            c[i][j] += (y - 1) * v;
            d[i][j] += (x - 1) * (y - 1) * v;
        }
    }
}

LL query(int x, int y) {
    LL ret = 0;
    for (int i = x; i; i -= lowbit(i))
        for (int j = y; j; j -= lowbit(j))
            ret += x * y * a[i][j] - y * b[i][j] - x * c[i][j] + d[i][j];
    return ret;
}

int main() {
    //加快讀入
    ios::sync_with_stdio(false), cin.tie(0);
    cin >> n >> m;
    int opt;
    while (cin >> opt) {
        int x1, y1, x2, y2;
        cin >> x1 >> y1 >> x2 >> y2;
        if (opt == 1) {
            int v;
            cin >> v;
            add(x1, y1, v);
            add(x1, y2 + 1, -v);
            add(x2 + 1, y1, -v);
            add(x2 + 1, y2 + 1, v);
        } else
            cout << query(x2, y2) - query(x1 - 1, y2) - query(x2, y1 - 1) + query(x1 - 1, y1 - 1) << '\n';
    }
    return 0;
}