1. 程式人生 > 實用技巧 >樹上差分學習筆記

樹上差分學習筆記

引入

1172. 祖孫詢問

Code:板子

P3379 【模板】最近公共祖先(LCA)

Code:又一道板子

P3258 [JLOI2014]松鼠的新家

樹上差分板子:讓訪問順序中每個點到下一個點之間的路徑上的所有點都加 1,最後再把被重複計數的點減回去

Code:差分板子

P3128 [USACO15DEC]Max Flow P

Code:也是差分板子

總結一下樹上差分:

分為點差分和邊差分:前者統計路徑上「點」的資訊,後者統計路徑上「邊」的資訊。舉個例子:路徑點數、路徑邊數

  • 點差分:在路徑的兩個端點處加1,在lca處減1,在lca的父親處減1

  • 邊差分:在路徑的兩個端點處加1,在lca處減2

重點:“減還是加”?“統計從上到下還是子樹資訊?”

第一個問題上面已經說到,第二個問題:將差分變為真實值是統計子樹資訊,此時每個點的單點答案已經正確(例如:一個點被覆蓋的次數),也就是,完成了“區間修改,單點查詢”

如果還想查詢一條鏈的資訊,就要從上到下統計資訊,然後在lca處容斥即可(分點差分和邊差分),也就是“區間修改,區間查詢”

實戰

商人 5.0

要用到上面所說的“區間修改,區間查詢”

Code


Code1:
#include <bits/stdc++.h>
using namespace std;
const int MA = 1 << 23;
// Input buffer refilled from stdin up to MA bytes at a time; [p1, p2) is the
// part that has been loaded but not consumed yet.
char buf[MA], *p1 = buf, *p2 = buf;
// Yields the next input byte as a value in 0..255, or EOF once stdin is
// exhausted. The byte goes through unsigned char so that a 0xFF byte cannot be
// confused with EOF and the result is always a valid argument for isdigit().
#define gc()                                                            \
    (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MA, stdin), p1 == p2) \
         ? EOF                                                          \
         : static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Every non-digit byte in front of the number is skipped; a '-' anywhere in
// that skipped prefix makes the result negative.
// Returns 0 when the input ends before any digit is found (the previous
// version looped forever in that situation).
inline int read() {
    int f = 0, r = 1;
    int ch = gc();  // int rather than char so EOF stays distinguishable
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') r = -1;
        ch = gc();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at end of input
        f = f * 10 + (ch - '0');
        ch = gc();
    }
    return f * r;
}
const int N = 40005;
// fa[v][k] holds the 2^k-th ancestor of v (0 when there is none), dep[v] the
// depth of v counted from dep[root] = 1, and q is the storage of the BFS queue.
int n, m, root, fa[N][20], dep[N], q[N];
// Undirected adjacency lists: every edge is stored at both of its endpoints.
vector<int> son[N];
// Breadth-first traversal starting at `root` that fills dep[] and the
// binary-lifting ancestor table fa[][] for every node reached.
inline void get_fa_bfs() {
    // Every byte 0x3f gives a huge depth, which marks a node as not reached yet.
    memset(dep, 0x3f, sizeof dep);
    // Node 0 acts as the sentinel above the root.
    dep[0] = 0, dep[root] = 1;
    int head = 0, tail = 0;
    q[tail] = root;
    while (head <= tail) {
        const int cur = q[head];
        ++head;
        const vector<int>& adj = son[cur];
        for (size_t idx = 0; idx < adj.size(); ++idx) {
            const int nxt = adj[idx];
            // Not deeper than cur: it has been reached before, skip it.
            if (dep[nxt] <= dep[cur]) continue;
            dep[nxt] = dep[cur] + 1;
            fa[nxt][0] = cur;
            q[++tail] = nxt;
            // All ancestors of nxt were dequeued earlier, so their rows are complete.
            for (int lvl = 1; (1 << lvl) <= n; ++lvl)
                fa[nxt][lvl] = fa[fa[nxt][lvl - 1]][lvl - 1];
        }
    }
}
// Returns the lowest common ancestor of U and V using the table built by
// get_fa_bfs(): dep[] for depths and fa[][] for 2^k-th ancestors.
inline int lca(int U, int V) {
    // Make U the deeper (or equally deep) node.
    if (dep[U] < dep[V]) swap(U, V);
    // Raise U as long as it does not end up above V's depth.
    for (int k = 15; k >= 0; --k) {
        const int up = fa[U][k];
        if (dep[up] >= dep[V]) U = up;
    }
    if (U == V) return U;
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for (int k = 15; k >= 0; --k) {
        const int upU = fa[U][k];
        const int upV = fa[V][k];
        if (upU == upV) continue;
        U = upU;
        V = upV;
    }
    // U and V are now distinct children of the answer.
    return fa[U][0];
}
// Reads n pairs (u, v): a pair with v == -1 names u as the root, every other
// pair is stored as an undirected edge u-v. After building the ancestor table
// it reads m queries (u, v) and prints one line per query:
//   1 if lca(u, v) == u, otherwise 2 if lca(u, v) == v, otherwise 0.
signed main() {
    n = read();
    for (int i = 1; i <= n; ++i) {
        const int u = read();
        const int v = read();
        if (v == -1) {
            root = u;
        } else {
            son[u].push_back(v);
            son[v].push_back(u);
        }
    }
    get_fa_bfs();
    m = read();
    for (int i = 1; i <= m; ++i) {
        const int u = read();
        const int v = read();
        // Query the LCA once and reuse it for both comparisons.
        const int anc = lca(u, v);
        putchar(anc == u ? '1' : (anc == v ? '2' : '0'));
        putchar('\n');
    }
    return 0;
}
Code2:
#include <bits/stdc++.h>
using namespace std;
const int MA = 1 << 23;
// Input buffer refilled from stdin up to MA bytes at a time; [p1, p2) is the
// part that has been loaded but not consumed yet.
char buf[MA], *p1 = buf, *p2 = buf;
// Yields the next input byte as a value in 0..255, or EOF once stdin is
// exhausted. The byte goes through unsigned char so that a 0xFF byte cannot be
// confused with EOF and the result is always a valid argument for isdigit().
#define gc()                                                            \
    (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MA, stdin), p1 == p2) \
         ? EOF                                                          \
         : static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Every non-digit byte in front of the number is skipped; a '-' anywhere in
// that skipped prefix makes the result negative.
// Returns 0 when the input ends before any digit is found (the previous
// version looped forever in that situation).
inline int read() {
    int ff = 0, rr = 1;
    int ch = gc();  // int rather than char so EOF stays distinguishable
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') rr = -1;
        ch = gc();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at end of input
        ff = ff * 10 + (ch - '0');
        ch = gc();
    }
    return ff * rr;
}
// Writes x in decimal to stdout through putchar, preceded by '-' when negative.
void print(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect the digits least-significant first, then emit them in reverse.
    char digits[12];
    int len = 0;
    do {
        digits[len++] = static_cast<char>(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar(digits[--len]);
}
const int N = 500005;
// fa[v][k] holds the 2^k-th ancestor of v (0 when there is none), dep[v] the
// depth of v counted from dep[root] = 1, and q is the storage of the BFS queue.
int n, m, root, fa[N][25], dep[N], q[N];
// Undirected adjacency lists: every edge is stored at both of its endpoints.
vector<int> son[N];
// Breadth-first traversal starting at `root` that fills dep[] and the
// binary-lifting ancestor table fa[][] for every node reached.
inline void get_fa_bfs() {
    // Every byte 0x3f gives a huge depth, which marks a node as not reached yet.
    memset(dep, 0x3f, sizeof dep);
    int hh = 0, tt = 0;
    // Node 0 acts as the sentinel above the root.
    dep[root] = 1, dep[0] = 0, q[0] = root;
    while (hh <= tt) {
        int x = q[hh++];
        for (size_t i = 0; i < son[x].size(); ++i) {
            int y = son[x][i];
            // Not deeper than x: it has been reached before, skip it.
            if (dep[y] < dep[x] + 1) continue;
            dep[y] = dep[x] + 1, fa[y][0] = x, q[++tt] = y;
            // All ancestors of y were dequeued earlier, so their rows are complete.
            for (int k = 1; (1 << k) <= n; ++k)
                fa[y][k] = fa[fa[y][k - 1]][k - 1];
        }
    }
}
// Returns the lowest common ancestor of U and V using the table built by
// get_fa_bfs().
inline int lca(int U, int V) {
    // Highest lifting level to try. Levels 0..19 together cover 2^20 - 1 steps,
    // which exceeds N, so any depth gap can be bridged. The previous bound of
    // 15 covered only 65535 steps and gave wrong answers on deeper trees.
    // Levels that get_fa_bfs() did not fill are 0 (the sentinel), so trying
    // them is harmless. 19 is within the 25 columns of fa.
    const int TOP = 19;
    // Make U the deeper (or equally deep) node.
    if (dep[U] < dep[V]) swap(U, V);
    // Raise U as long as it does not end up above V's depth.
    for (int k = TOP; k >= 0; --k)
        if (dep[fa[U][k]] >= dep[V]) U = fa[U][k];
    if (U == V) return U;
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for (int k = TOP; k >= 0; --k)
        if (fa[U][k] != fa[V][k]) U = fa[U][k], V = fa[V][k];
    // U and V are now distinct children of the answer.
    return fa[U][0];
}
// Reads n, m and root, then n - 1 undirected edges, builds the ancestor table
// and prints lca(u, v) on its own line for each of the m queries.
signed main() {
    n = read();
    m = read();
    root = read();
    for (int e = 1; e < n; ++e) {
        const int a = read();
        const int b = read();
        son[a].push_back(b);
        son[b].push_back(a);
    }
    get_fa_bfs();
    for (int done = 0; done < m; ++done) {
        const int a = read();
        const int b = read();
        print(lca(a, b));
        putchar('\n');
    }
    return 0;
}
Code3:
#include <bits/stdc++.h>
using namespace std;
const int MA = 1 << 23;
// Input buffer refilled from stdin up to MA bytes at a time; [p1, p2) is the
// part that has been loaded but not consumed yet.
char buf[MA], *p1 = buf, *p2 = buf;
// Yields the next input byte as a value in 0..255, or EOF once stdin is
// exhausted. The byte goes through unsigned char so that a 0xFF byte cannot be
// confused with EOF and the result is always a valid argument for isdigit().
#define gc()                                                            \
    (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MA, stdin), p1 == p2) \
         ? EOF                                                          \
         : static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Every non-digit byte in front of the number is skipped; a '-' anywhere in
// that skipped prefix makes the result negative.
// Returns 0 when the input ends before any digit is found (the previous
// version looped forever in that situation).
inline int read() {
    int ff = 0, rr = 1;
    int ch = gc();  // int rather than char so EOF stays distinguishable
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') rr = -1;
        ch = gc();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at end of input
        ff = ff * 10 + (ch - '0');
        ch = gc();
    }
    return ff * rr;
}
// Writes x in decimal to stdout through putchar, preceded by '-' when negative.
void print(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect the digits least-significant first, then emit them in reverse.
    char digits[12];
    int len = 0;
    do {
        digits[len++] = static_cast<char>(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar(digits[--len]);
}
const int N = 300005;
// fa[v][k] holds the 2^k-th ancestor of v (0 when there is none), dep[v] the
// depth of v counted from dep[1] = 1, q is the storage of the BFS queue,
// a[] the sequence of nodes read by main and point[] the per-node counters.
int n, fa[N][20], dep[N], q[N], a[N], point[N];
// Undirected adjacency lists: every edge is stored at both of its endpoints.
vector<int> son[N];
// Breadth-first traversal starting at node 1 that fills dep[] and the
// binary-lifting ancestor table fa[][] for every node reached.
inline void get_fa_bfs() {
    // Every byte 0x3f gives a huge depth, which marks a node as not reached yet.
    memset(dep, 0x3f, sizeof dep);
    // Node 0 acts as the sentinel above the root.
    dep[0] = 0, dep[1] = 1;
    int head = 0, tail = 0;
    q[tail] = 1;
    while (head <= tail) {
        const int cur = q[head];
        ++head;
        const vector<int>& adj = son[cur];
        for (size_t idx = 0; idx < adj.size(); ++idx) {
            const int nxt = adj[idx];
            // Not deeper than cur: it has been reached before, skip it.
            if (dep[nxt] <= dep[cur]) continue;
            dep[nxt] = dep[cur] + 1;
            fa[nxt][0] = cur;
            q[++tail] = nxt;
            // All ancestors of nxt were dequeued earlier, so their rows are complete.
            for (int lvl = 1; (1 << lvl) <= n; ++lvl)
                fa[nxt][lvl] = fa[fa[nxt][lvl - 1]][lvl - 1];
        }
    }
}
// Returns the lowest common ancestor of U and V using the table built by
// get_fa_bfs(): dep[] for depths and fa[][] for 2^k-th ancestors.
inline int lca(int U, int V) {
    // Make U the deeper (or equally deep) node.
    if (dep[U] < dep[V]) swap(U, V);
    // Raise U as long as it does not end up above V's depth.
    for (int k = 19; k >= 0; --k) {
        const int up = fa[U][k];
        if (dep[up] >= dep[V]) U = up;
    }
    if (U == V) return U;
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for (int k = 19; k >= 0; --k) {
        const int upU = fa[U][k];
        const int upV = fa[V][k];
        if (upU == upV) continue;
        U = upU;
        V = upV;
    }
    // U and V are now distinct children of the answer.
    return fa[U][0];
}
// Post-order walk below x (F is the neighbour x was entered from) that adds the
// point[] value of every child subtree into point[x].
void get_sum_dfs(int x, int F) {
    const vector<int>& adj = son[x];
    for (size_t idx = 0; idx < adj.size(); ++idx) {
        const int child = adj[idx];
        if (child != F) {
            // The child's total must be final before it is folded into x.
            get_sum_dfs(child, x);
            point[x] += point[child];
        }
    }
}
// Reads n, the node sequence a[1..n] and n - 1 undirected edges. For every
// consecutive pair (a[i-1], a[i]) it marks the tree path between them in the
// difference array point[] (+1 at both ends, -1 at their LCA and at the LCA's
// parent), turns the marks into per-node totals with get_sum_dfs, subtracts 1
// from each of a[2..n], and prints point[1..n] one value per line.
signed main() {
    n = read();
    for (int i = 1; i <= n; ++i) a[i] = read();
    for (int i = 1; i < n; ++i) {
        const int u = read();
        const int v = read();
        son[u].push_back(v);
        son[v].push_back(u);
    }
    get_fa_bfs();
    for (int i = 2; i <= n; ++i) {
        const int from = a[i - 1];
        const int to = a[i];
        // Query the LCA once per pair and reuse it for both decrements.
        const int top = lca(from, to);
        ++point[from];
        ++point[to];
        --point[top];
        // For top == 1 this touches the sentinel slot point[0], which is never printed.
        --point[fa[top][0]];
    }
    get_sum_dfs(1, 0);
    for (int i = 2; i <= n; ++i) --point[a[i]];
    for (int i = 1; i <= n; ++i) {
        print(point[i]);
        putchar('\n');
    }
    return 0;
}
Code4:
#include <bits/stdc++.h>
using namespace std;

const int N = 100005;
// Undirected adjacency lists: every edge is stored at both of its endpoints.
vector <int> G[N];
// ans is the running maximum maintained by get_sum_dfs, point[] the per-node counters.
int n, m, ans, point[N];
// dep[v]: depth of v counted from dep[1] = 1; fa[v][k]: the 2^k-th ancestor of
// v (0 when there is none); q: storage of the BFS queue.
int dep[N], fa[N][25], q[N];

// Breadth-first traversal starting at node 1 that fills dep[] and the
// binary-lifting ancestor table fa[][] for every node reached.
inline void get_fa_bfs() {
    // Every byte 0x3f gives a huge depth, which marks a node as not reached yet.
    memset(dep, 0x3f, sizeof dep);
    int head = 0, tail = 0;
    // Node 0 acts as the sentinel above the root.
    q[tail] = 1, dep[0] = 0, dep[1] = 1;
    while (head <= tail) {
        const int cur = q[head];
        ++head;
        const vector<int>& adj = G[cur];
        for (size_t idx = 0; idx < adj.size(); ++idx) {
            const int nxt = adj[idx];
            // Not deeper than cur: it has been reached before, skip it.
            if (dep[nxt] <= dep[cur]) continue;
            dep[nxt] = dep[cur] + 1;
            fa[nxt][0] = cur;
            q[++tail] = nxt;
            // All ancestors of nxt were dequeued earlier, so their rows are complete.
            for (int lvl = 1; (1 << lvl) <= n; ++lvl)
                fa[nxt][lvl] = fa[fa[nxt][lvl - 1]][lvl - 1];
        }
    }
}

// Returns the lowest common ancestor of U and V using the table built by
// get_fa_bfs(): dep[] for depths and fa[][] for 2^k-th ancestors.
inline int lca(int U, int V) {
    // Make U the deeper (or equally deep) node.
    if (dep[U] < dep[V]) swap(U, V);
    // Raise U as long as it does not end up above V's depth.
    for (int k = 21; k >= 0; --k) {
        const int up = fa[U][k];
        if (dep[up] >= dep[V]) U = up;
    }
    if (U == V) return U;
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for (int k = 21; k >= 0; --k) {
        const int upU = fa[U][k];
        const int upV = fa[V][k];
        if (upU == upV) continue;
        U = upU;
        V = upV;
    }
    // U and V are now distinct children of the answer.
    return fa[U][0];
}

// Post-order walk below x (F is the neighbour x was entered from) that adds the
// point[] value of every child subtree into point[x] and raises `ans` to the
// largest finished point[] value seen.
void get_sum_dfs(int x, int F) {
    const vector<int>& adj = G[x];
    for (size_t idx = 0; idx < adj.size(); ++idx) {
        const int child = adj[idx];
        if (child != F) {
            // The child's total must be final before it is folded into x.
            get_sum_dfs(child, x);
            point[x] += point[child];
        }
    }
    if (point[x] > ans) ans = point[x];
}

signed main() {
    ios::sync_with_stdio(0);cin.tie(0);cout.tie(0);

    cin >> n >> m;
    for (int i = 1, x, y; i < n; ++i) {
        cin >> x >> y;
        G[x].push_back(y), G[y].push_back(x);
    }
    get_fa_bfs();
    for (int i = 1, x, y; i <= m; ++i) {
        cin >> x >> y;
        int t = lca(x, y);
        ++ point[x], ++ point[y];
        -- point[t], -- point[fa[t][0]];
    }
    get_sum_dfs(1, 1);
    cout << ans;

    return 0;
}
Code5:
#include <bits/stdc++.h>
using namespace std;
const int MA = 1 << 23;
// Input buffer refilled from stdin up to MA bytes at a time; [p1, p2) is the
// part that has been loaded but not consumed yet.
char buf[MA], *p1 = buf, *p2 = buf;
// Yields the next input byte as a value in 0..255, or EOF once stdin is
// exhausted. The byte goes through unsigned char so that a 0xFF byte cannot be
// confused with EOF and the result is always a valid argument for isdigit().
#define gc()                                                            \
    (p1 == p2 && (p2 = (p1 = buf) + fread(buf, 1, MA, stdin), p1 == p2) \
         ? EOF                                                          \
         : static_cast<unsigned char>(*p1++))
// Reads the next decimal integer from stdin.
// Every non-digit byte in front of the number is skipped; a '-' anywhere in
// that skipped prefix makes the result negative.
// Returns 0 when the input ends before any digit is found (the previous
// version looped forever in that situation).
inline int read() {
    int ff = 0, rr = 1;
    int ch = gc();  // int rather than char so EOF stays distinguishable
    while (ch != EOF && !isdigit(ch)) {
        if (ch == '-') rr = -1;
        ch = gc();
    }
    while (isdigit(ch)) {  // isdigit(EOF) is false, so this also stops at end of input
        ff = ff * 10 + (ch - '0');
        ch = gc();
    }
    return ff * rr;
}
// Writes x in decimal to stdout through putchar, preceded by '-' when negative.
void print(int x) {
    if (x < 0) {
        putchar('-');
        x = -x;
    }
    // Collect the digits least-significant first, then emit them in reverse.
    char digits[12];
    int len = 0;
    do {
        digits[len++] = static_cast<char>(x % 10 + '0');
        x /= 10;
    } while (x > 0);
    while (len > 0) putchar(digits[--len]);
}
const int N = 200005;
// point[] / edge[]: the two per-node difference arrays; fa[v][k]: the 2^k-th
// ancestor of v (0 when there is none); dep[v]: depth counted from dep[1] = 1;
// q: storage of the BFS queue; u[i], v[i], d[i]: endpoints and LCA of query i.
int n, m, point[N], edge[N], fa[N][20], dep[N], q[N], u[N], v[N], d[N];
// Undirected adjacency lists: every edge is stored at both of its endpoints.
vector<int> son[N];
// Breadth-first traversal starting at node 1 that fills dep[] and the
// binary-lifting ancestor table fa[][] for every node reached.
inline void get_fa_bfs() {
    // Every byte 0x3f gives a huge depth, which marks a node as not reached yet.
    memset(dep, 0x3f, sizeof dep);
    // Node 0 acts as the sentinel above the root.
    dep[0] = 0, dep[1] = 1;
    int head = 0, tail = 0;
    q[tail] = 1;
    while (head <= tail) {
        const int cur = q[head];
        ++head;
        const vector<int>& adj = son[cur];
        for (size_t idx = 0; idx < adj.size(); ++idx) {
            const int nxt = adj[idx];
            // Not deeper than cur: it has been reached before, skip it.
            if (dep[nxt] <= dep[cur]) continue;
            dep[nxt] = dep[cur] + 1;
            fa[nxt][0] = cur;
            q[++tail] = nxt;
            // All ancestors of nxt were dequeued earlier, so their rows are complete.
            for (int lvl = 1; (1 << lvl) <= n; ++lvl)
                fa[nxt][lvl] = fa[fa[nxt][lvl - 1]][lvl - 1];
        }
    }
}
// Returns the lowest common ancestor of U and V using the table built by
// get_fa_bfs(): dep[] for depths and fa[][] for 2^k-th ancestors.
inline int lca(int U, int V) {
    // Make U the deeper (or equally deep) node.
    if (dep[U] < dep[V]) swap(U, V);
    // Raise U as long as it does not end up above V's depth.
    for (int k = 19; k >= 0; --k) {
        const int up = fa[U][k];
        if (dep[up] >= dep[V]) U = up;
    }
    if (U == V) return U;
    // Raise both nodes together while their 2^k-th ancestors still differ.
    for (int k = 19; k >= 0; --k) {
        const int upU = fa[U][k];
        const int upV = fa[V][k];
        if (upU == upV) continue;
        U = upU;
        V = upV;
    }
    // U and V are now distinct children of the answer.
    return fa[U][0];
}
// Post-order walk below x (F is the neighbour x was entered from) that adds the
// point[] and edge[] values of every child subtree into those of x.
void get_sum_dfs(int x, int F) {
    const vector<int>& adj = son[x];
    for (size_t idx = 0; idx < adj.size(); ++idx) {
        const int child = adj[idx];
        if (child != F) {
            // The child's totals must be final before they are folded into x.
            get_sum_dfs(child, x);
            point[x] += point[child];
            edge[x] += edge[child];
        }
    }
}
// Pre-order walk below x (F is the neighbour x was entered from) that adds the
// point[] and edge[] values of x onto each child before descending, so every
// node ends up with the sum over the chain from the start node down to itself.
void get_ans_dfs(int x, int F) {
    const vector<int>& adj = son[x];
    for (size_t idx = 0; idx < adj.size(); ++idx) {
        const int child = adj[idx];
        if (child == F) continue;
        // Push the parent's accumulated values down first, then recurse.
        point[child] += point[x];
        edge[child] += edge[x];
        get_ans_dfs(child, x);
    }
}
// Reads n and m, then n - 1 undirected edges and m queries (u[i], v[i]).
// Phase 1 records every query path in two difference arrays:
//   point[]: +1 at both endpoints, -1 at the LCA and at the LCA's parent;
//   edge[]:  -1 at both endpoints, +2 at the LCA.
// Phase 2 turns them into subtree sums (get_sum_dfs) and then into sums along
// the chain from node 1 down to each node (get_ans_dfs).
// Phase 3 prints, for each query, the combination of those chain sums taken at
// u[i], v[i], the LCA and the LCA's parent, one value per line.
signed main() {
    n = read();
    m = read();
    for (int i = 1; i < n; ++i) {
        const int U = read();
        const int V = read();
        son[U].push_back(V);
        son[V].push_back(U);
    }
    get_fa_bfs();
    for (int i = 1; i <= m; ++i) {
        // Read the endpoints in separate statements: the order in which function
        // arguments are evaluated is unspecified, so reading them inside the
        // lca(...) call could assign the two input numbers in either order.
        u[i] = read();
        v[i] = read();
        d[i] = lca(u[i], v[i]);
        ++point[u[i]];
        ++point[v[i]];
        --point[d[i]];
        --point[fa[d[i]][0]];
        --edge[u[i]];
        --edge[v[i]];
        edge[d[i]] += 2;
    }
    get_sum_dfs(1, 0);
    get_ans_dfs(1, 0);
    // Slot 0 stands for "above the root"; it collected marks in phase 1 but must
    // contribute nothing when it is read as the parent of node 1 below.
    point[0] = edge[0] = 0;
    for (int i = 1; i <= m; ++i) {
        const int top = d[i];
        print(point[u[i]] + point[v[i]] - point[top] - point[fa[top][0]] +
              edge[u[i]] + edge[v[i]] - edge[top] * 2 - 1);
        putchar('\n');
    }
    return 0;
}
Code6: