資料結構專題-專項訓練:樹鏈剖分
1. 前言
本篇博文為樹鏈剖分的演算法總結與專題訓練。
沒有學過樹鏈剖分?
傳送門:演算法學習筆記:樹鏈剖分
樹剖作為一種工具,可以有效解決各類樹上問題。
需要注意的是,藉助資料結構維護重鏈資訊的時候,不一定只是使用線段樹,平衡樹,分塊等等都可以使用。
當然這篇博文都是線段樹。
在往下看之前,請先確保學習過可持久化線段樹/動態開點線段樹/主席樹。
沒有學過可持久化線段樹/動態開點線段樹/主席樹?
傳送門:線段樹演算法總結&專題訓練4(可持久化線段樹/主席樹)
2. 題單
P3313 [SDOI2014]旅行
樹剖板子題。
考慮對這棵樹樹剖之後,使用線段樹來維護和與最大值的資訊,但是怎麼維護呢?
開 \(v\) 棵線段樹唄!然後在每一棵線段樹中維護對應的值,需要的時候就在對應線段樹中修改查詢。
結果空間複雜度為 \(O(4 \times n^2)\),只聽到一聲慘叫:“我 MLE 了!”
於是我們需要加一點優化。
學過可持久化線段樹的讀者就會發現這其實可以使用可持久化線段樹。
我們初始時開 \(v\) 棵線段樹,但是每棵線段樹都只有一個根節點,記做 \(root_i\)
然後需要的時候靜態開點即可。
空間複雜度:每次只會增加 \(\log n\) 個節點,則為可持久化線段樹的空間複雜度,\(O(20n)\)。
程式碼:
/* ========= Plozia ========= Author:Plozia Problem:P3313 [SDOI2014]旅行 Date:2021/3/9 ========= Plozia ========= */ #include <bits/stdc++.h> #define Max(a, b) ((a > b) ? a : b) using std::vector; using std::string; typedef long long LL; const int MAXN = 1e5 + 10, MAXX = 2e6 + 10; int n, q, wfir[MAXN], cfir[MAXN], w[MAXN], c[MAXN], root[MAXN], cnt; int Son[MAXN], Size[MAXN], dep[MAXN], fa[MAXN], Top[MAXN], id[MAXN]; vector <int> Next[MAXN]; struct node { int w, l, r, maxn; }tree[MAXN + MAXX + 10]; int read() { int sum = 0, fh = 1; char ch = getchar(); for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1; for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48); return (fh == 1) ? sum : -sum; } namespace Segment_tree { void change(int p, int l, int r, int k, int v, int w) { if (l == r) {tree[p].w = tree[p].maxn = w; return ;} int mid = (l + r) >> 1; if (k <= mid) { if (!tree[p].l) tree[p].l = ++cnt; change(tree[p].l, l, mid, k, v, w); } else { if (!tree[p].r) tree[p].r = ++cnt; change(tree[p].r, mid + 1, r, k, v, w); } tree[p].w = tree[tree[p].l].w + tree[tree[p].r].w; tree[p].maxn = Max(tree[tree[p].l].maxn, tree[tree[p].r].maxn); } int ask_sum(int p, int l1, int r1, int l2, int r2) { if (l1 >= l2 && r1 <= r2) return tree[p].w; int mid = (l1 + r1) >> 1, ans = 0; if (l2 <= mid) { if (tree[p].l) ans += ask_sum(tree[p].l, l1, mid, l2, r2); } if (r2 > mid) { if (tree[p].r) ans += ask_sum(tree[p].r, mid + 1, r1, l2, r2); } return ans; } int ask_max(int p, int l1, int r1, int l2, int r2) { if (l1 >= l2 && r1 <= r2) return tree[p].maxn; int mid = (l1 + r1) >> 1, ans = 0; if (l2 <= mid) { if (tree[p].l) ans = std::max(ans, ask_max(tree[p].l, l1, mid, l2, r2)); } if (r2 > mid) { if (tree[p].r) ans = std::max(ans, ask_max(tree[p].r, mid + 1, r1, l2, r2)); } return ans; } } void dfs1(int now, int father, int depth) { dep[now] = depth; fa[now] = father; Size[now] = 1; for (int i = 0; i < Next[now].size(); ++i) { int u = Next[now][i]; if (u == father) continue; dfs1(u, now, depth + 1); Size[now] += Size[u]; if (Size[u] > Size[Son[now]]) Son[now] = u; } } void dfs2(int now, int top_father) { id[now] = ++cnt; Top[now] = top_father; w[cnt] = wfir[now]; c[cnt] = cfir[now]; if (!Son[now]) return ; dfs2(Son[now], top_father); for (int i = 0; i < Next[now].size(); ++i) { int u = Next[now][i]; if (u == fa[now] || u == Son[now]) continue; dfs2(u, u); } } int main() { n = read(), q = read(); for (int i = 1; i <= n; ++i) wfir[i] = read(), cfir[i] = read(); for (int i = 1; i < n; ++i) { int x = read(), y = read(); Next[x].push_back(y), Next[y].push_back(x); } dfs1(1, 1, 1); dfs2(1, 1); for (int i = 1; i <= n; ++i) root[i] = i; cnt = n; for (int i = 1; i <= n; ++i) Segment_tree::change(root[c[i]], 1, n, i, c[i], w[i]); for (int i = 1; i <= q; ++i) { string str; std::cin >> str; if (str == "CC") { int x = read(), c_ = read(); Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], 0); c[id[x]] = c_; Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], w[id[x]]); } if (str == "CW") int x = read(), w_ = read(); w[id[x]] = w_; Segment_tree::change(root[c[id[x]]], 1, n, id[x], c[id[x]], w[id[x]]); } if (str == "QS") { int x = read(), y = read(); int ans = 0, c_ = c[id[x]]; while (Top[x] != Top[y]) { if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y); ans += Segment_tree::ask_sum(root[c_], 1, n, id[Top[x]], id[x]); x = fa[Top[x]]; } if (dep[x] > dep[y]) std::swap(x, y); ans += Segment_tree::ask_sum(root[c_], 1, n, id[x], id[y]); printf("%d\n", ans); } if (str == "QM") { int x = read(), y = read(); int ans = 0, c_ = c[id[x]]; while (Top[x] != Top[y]) { if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y); ans = std::max(ans, Segment_tree::ask_max(root[c_], 1, n, id[Top[x]], id[x])); x = fa[Top[x]]; } if (dep[x] > dep[y]) std::swap(x, y); ans = std::max(ans, Segment_tree::ask_max(root[c_], 1, n, id[x], id[y])); printf("%d\n", ans); } } return 0; }
P2486 [SDOI2011]染色
這道題是一道細節題。
解題思路還是比較明顯的,使用線段樹維護一下區間的頭元素,尾元素,答案,合併的時候注意一下頭尾元素的合併即可。
然後樹剖呢?直接剖一下,然後跳就可以了呀,注意臨界點的答案合併。
然後開始愉快的碼碼碼,然後……調了 3 個小時。
所以這道題細節到底在哪裡呢?
- 注意線段樹 \(update\) 的時候左兒子與右兒子可能會首尾相同。
- 如果你只是寫了一個 \(ask\) 函式而且這個函式只返回了區間的答案,那麼請注意:
在樹剖的時候當我們完成詢問 \([id_{Top_x},id_x]\) 的時候,一定要知道 \(id_{Top_x}\) 和 \(id_{fa_{Top_x}}\) 的顏色是否相同,因為這涉及到答案是否要減一。相同則需要減一,防止後面的詢問對這次產生干擾。
程式碼:
/*
========= Plozia =========
Author:Plozia
Problem:P2486 [SDOI2011]染色
Date:2021/3/9
========= Plozia =========
*/
#include <bits/stdc++.h>
using std::vector;
typedef long long LL;
const int MAXN = 1e5 + 10;
int n, m, wfir[MAXN];
int Size[MAXN], Son[MAXN], dep[MAXN], fa[MAXN], Top[MAXN], id[MAXN], w[MAXN], cnt;
vector <int> Next[MAXN];
struct node
{
int firnum, lasnum;
int sum, l, r, add;
#define l(p) tree[p].l
#define r(p) tree[p].r
#define s(p) tree[p].sum
#define a(p) tree[p].add
#define fir(p) tree[p].firnum
#define las(p) tree[p].lasnum
}tree[MAXN << 2];
int read()
{
int sum = 0, fh = 1; char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
return (fh == 1) ? sum : -sum;
}
namespace Segment_tree
{
void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if (l == r) {s(p) = 1; fir(p) = las(p) = w[l]; return ;}
int mid = (l + r) >> 1;
build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
s(p) = s(p << 1) + s(p << 1 | 1);
if (las(p << 1) == fir(p << 1 | 1)) --s(p);
fir(p) = fir(p << 1), las(p) = las(p << 1 | 1);
}
void spread(int p)
{
if (a(p))
{
s(p << 1) = s(p << 1 | 1) = 1;
a(p << 1) = a(p << 1 | 1) = a(p);
fir(p << 1) = fir(p << 1 | 1) = las(p << 1) = las(p << 1 | 1) = a(p);
a(p) = 0;
}
}
void change(int p, int l, int r, int c)
{
if (l(p) >= l && r(p) <= r) {s(p) = 1; a(p) = fir(p) = las(p) = c; return ;}
spread(p);
int mid = (l(p) + r(p)) >> 1;
if (l <= mid) change(p << 1, l, r, c);
if (r > mid) change(p << 1 | 1, l, r, c);
s(p) = s(p << 1) + s(p << 1 | 1);
if (las(p << 1) == fir(p << 1 | 1)) --s(p);
fir(p) = fir(p << 1), las(p) = las(p << 1 | 1);
}
int ask(int p, int l, int r)
{
if (l(p) >= l && r(p) <= r) return s(p);
spread(p);
int mid = (l(p) + r(p)) >> 1, ans = 0;
if (l <= mid && r > mid)
{
ans = ask(p << 1, l, r) + ask(p << 1 | 1, l, r);
if (las(p << 1) == fir(p << 1 | 1)) --ans;
}
else if (l <= mid) ans = ask(p << 1, l, r);
else if (r > mid) ans = ask(p << 1 | 1, l, r);
return ans;
}
int ask2(int p, int k)
{
if (l(p) == r(p) && r(p) == k) return fir(p);
spread(p);
int mid = (l(p) + r(p)) >> 1;
if (k <= mid) return ask2(p << 1, k);
else return ask2(p << 1 | 1, k);
}
}
void dfs1(int now, int father, int depth)
{
dep[now] = depth; fa[now] = father; Size[now] = 1;
for (int i = 0; i < Next[now].size(); ++i)
{
int u = Next[now][i];
if (u == father) continue;
dfs1(u, now, depth + 1);
Size[now] += Size[u];
if (Size[u] > Size[Son[now]]) Son[now] = u;
}
}
void dfs2(int now, int top_father)
{
id[now] = ++cnt; Top[now] = top_father; w[cnt] = wfir[now];
if (!Son[now]) return ;
dfs2(Son[now], top_father);
for (int i = 0; i < Next[now].size(); ++i)
{
int u = Next[now][i];
if (u == fa[now] || u == Son[now]) continue;
dfs2(u, u);
}
}
int main()
{
n = read(), m = read();
for (int i = 1; i <= n; ++i) wfir[i] = read();
for (int i = 1; i < n; ++i)
{
int x = read(), y = read();
Next[x].push_back(y), Next[y].push_back(x);
}
dfs1(1, 1, 1); dfs2(1, 1); Segment_tree::build(1, 1, n);
for (int i = 1; i <= m; ++i)
{
char ch; std::cin >> ch;
if (ch == 'C')
{
int a = read(), b = read(), c = read();
while (Top[a] != Top[b])
{
if (dep[Top[a]] < dep[Top[b]]) std::swap(a, b);
Segment_tree::change(1, id[Top[a]], id[a], c);
a = fa[Top[a]];
}
if (dep[a] > dep[b]) std::swap(a, b);
Segment_tree::change(1, id[a], id[b], c);
}
if (ch == 'Q')
{
int a = read(), b = read(), ans = 0;
while (Top[a] != Top[b])
{
if (dep[Top[a]] < dep[Top[b]]) std::swap(a, b);
ans += Segment_tree::ask(1, id[Top[a]], id[a]);
if (Segment_tree::ask2(1, id[Top[a]]) == Segment_tree::ask2(1, id[fa[Top[a]]])) --ans;
a = fa[Top[a]];
}
if (dep[a] > dep[b]) std::swap(a, b);
ans += Segment_tree::ask(1, id[a], id[b]);
printf("%d\n", ans);
}
}
return 0;
}
P1505 [國家集訓隊]旅遊
也是一道樹剖題。
這道題首先需要『邊權轉點權』。
邊權轉點權的方式如下:
對於第 \(i\) 條邊 \((u,v,w)\),我們取深度較大的這個點,假設為 \(x\),則 \(x\) 的點權為 \(w\)。
這樣,除根節點之外,每一個點都均勻有一個點權。
然後就可以愉快的樹剖啦!
線段樹需要注意的是,一個區間取反兩次就是沒有取反,程式碼中我採用的是異或的性質來處理。
還有,當 \(x,y\) 跳完,在同一條重鏈的時候,需要特判一下 \(x\) 是否等於 \(y\),因為 \(x,y\) 中深度較小的點記錄的點權是不算在路徑上的(為其與父節點的路徑長度),需要過濾。
細節還是很多的,程式碼量也很大。
程式碼:
/*
========= Plozia =========
Author:Plozia
Problem:P1505 [國家集訓隊]旅遊
Date:2021/3/12
========= Plozia =========
*/
#include <bits/stdc++.h>
#define Max(a, b) ((a > b) ? a : b)
#define Min(a, b) ((a < b) ? a : b)
using std::vector;
using std::string;
typedef long long LL;
const int MAXN = 2e5 + 10;
int n, m, afir[MAXN];
int cnt, Top[MAXN], id[MAXN], fa[MAXN], dep[MAXN], Size[MAXN], Son[MAXN], a[MAXN];
struct Edge
{
int x, y, z;
}e[MAXN];
struct node
{
int l, r, sum, maxn, minn, add;
#define l(p) tree[p].l
#define r(p) tree[p].r
#define s(p) tree[p].sum
#define maxn(p) tree[p].maxn
#define minn(p) tree[p].minn
#define a(p) tree[p].add
}tree[MAXN << 2];
vector <int> Next[MAXN], Num[MAXN];
int read()
{
int sum = 0, fh = 1; char ch = getchar();
for (; ch < '0' || ch > '9'; ch = getchar()) fh -= (ch == '-') << 1;
for (; ch >= '0' && ch <= '9'; ch = getchar()) sum = (sum << 3) + (sum << 1) + (ch ^ 48);
return (fh == 1) ? sum : -sum;
}
namespace Segment_tree
{
void build(int p, int l, int r)
{
l(p) = l, r(p) = r;
if (l == r) {s(p) = maxn(p) = minn(p) = a[l]; return ;}
int mid = (l + r) >> 1;
build(p << 1, l, mid); build(p << 1 | 1, mid + 1, r);
s(p) = s(p << 1) + s(p << 1 | 1);
maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
}
void spread(int p)
{
if (a(p) != 0)
{
a(p << 1) ^= 1; a(p << 1 | 1) ^= 1;
s(p << 1) *= -1; s(p << 1 | 1) *= -1;
int Maxn, Minn;
Maxn = maxn(p << 1), Minn = minn(p << 1);
minn(p << 1) = -Maxn, maxn(p << 1) = -Minn;
Maxn = maxn(p << 1 | 1), Minn = minn(p << 1 | 1);
minn(p << 1 | 1) = -Maxn, maxn(p << 1 | 1) = -Minn;
a(p) = 0;
}
}
void change_1(int p, int k, int w)
{
if (l(p) == r(p) && l(p) == k) {s(p) = maxn(p) = minn(p) = w; return ;}
spread(p);
int mid = (l(p) + r(p)) >> 1;
if (k <= mid) change_1(p << 1, k, w);
else change_1(p << 1 | 1, k, w);
s(p) = s(p << 1) + s(p << 1 | 1);
maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
}
void change_2(int p, int l, int r)
{
if (l(p) >= l && r(p) <= r)
{
a(p) ^= 1; s(p) *= -1;
int fir = maxn(p), sec = minn(p);
minn(p) = fir * -1, maxn(p) = sec * -1; return ;
}
spread(p); int mid = (l(p) + r(p)) >> 1;
if (l <= mid) change_2(p << 1, l, r);
if (r > mid) change_2(p << 1 | 1, l, r);
s(p) = s(p << 1) + s(p << 1 | 1);
maxn(p) = Max(maxn(p << 1), maxn(p << 1 | 1));
minn(p) = Min(minn(p << 1), minn(p << 1 | 1));
}
int ask_sum(int p, int l, int r)
{
if (l(p) >= l && r(p) <= r) return s(p);
spread(p); int mid = (l(p) + r(p)) >> 1, ans = 0;
if (l <= mid) ans += ask_sum(p << 1, l, r);
if (r > mid) ans += ask_sum(p << 1 | 1, l, r);
return ans;
}
int ask_maxn(int p, int l, int r)
{
if (l(p) >= l && r(p) <= r) return maxn(p);
spread(p); int mid = (l(p) + r(p)) >> 1, ans = -0x7f7f7f7f, tmp = -0x7f7f7f7f;
if (l <= mid) {tmp = ask_maxn(p << 1, l, r); ans = Max(ans, tmp);}
if (r > mid) {tmp = ask_maxn(p << 1 | 1, l, r); ans = Max(ans, tmp);}
return ans;
}
int ask_minn(int p, int l, int r)
{
if (l(p) >= l && r(p) <= r) return minn(p);
spread(p); int mid = (l(p) + r(p)) >> 1, ans = 0x7f7f7f7f, tmp = 0x7f7f7f7f;
if (l <= mid) {tmp = ask_minn(p << 1, l, r); ans = Min(ans, tmp);}
if (r > mid) {tmp = ask_minn(p << 1 | 1, l, r); ans = Min(ans, tmp);}
return ans;
}
}
void dfs1(int now, int father, int depth)
{
dep[now] = depth; fa[now] = father; Size[now] = 1;
for (int i = 0; i < Next[now].size(); ++i)
{
int u = Next[now][i];
if (u == father) continue;
dfs1(u, now, depth + 1);
Size[now] += Size[u];
if (Size[u] > Size[Son[now]]) Son[now] = u;
}
}
void dfs2(int now, int top_father)
{
Top[now] = top_father; id[now] = ++cnt; a[cnt] = afir[now];
if (!Son[now]) return ; dfs2(Son[now], top_father);
for (int i = 0; i < Next[now].size(); ++i)
{
int u = Next[now][i];
if (u == fa[now] || u == Son[now]) continue ;
dfs2(u, u);
}
}
int main()
{
n = read();
for (int i = 1; i < n; ++i)
{
int x = read() + 1, y = read() + 1, z = read();
e[i] = (Edge){x, y, z};
Next[x].push_back(y), Next[y].push_back(x);
Num[x].push_back(z), Num[y].push_back(z);
}
dfs1(1, 1, 1);
for (int i = 1; i < n; ++i)
{
int x = e[i].x, y = e[i].y, z = e[i].z;
if (dep[x] > dep[y]) afir[x] = z;
else afir[y] = z;
}//邊權轉點權
dfs2(1, 1); Segment_tree::build(1, 1, n);
m = read();
for (int i = 1; i <= m; ++i)
{
string str; std::cin >> str;
if (str == "C")
{
int k = read(), w = read();
Segment_tree::change_1(1, id[(dep[e[k].x] > dep[e[k].y]) ? e[k].x : e[k].y], w);
}
if (str == "N")
{
int x = read() + 1, y = read() + 1;
while (Top[x] != Top[y])
{
if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
Segment_tree::change_2(1, id[Top[x]], id[x]);
x = fa[Top[x]];
}
if (dep[x] > dep[y]) std::swap(x, y);
if (x != y) Segment_tree::change_2(1, id[x] + 1, id[y]);
}
if (str == "SUM")
{
int x = read() + 1, y = read() + 1, ans = 0;
while (Top[x] != Top[y])
{
if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
ans += Segment_tree::ask_sum(1, id[Top[x]], id[x]);
x = fa[Top[x]];
}
if (dep[x] > dep[y]) std::swap(x, y);
if (x != y) ans += Segment_tree::ask_sum(1, id[x] + 1, id[y]);
printf("%d\n", ans);
}
if (str == "MAX")
{
int x = read() + 1, y = read() + 1, ans = -0x7f7f7f7f, tmp = -0x7f7f7f7f;
while (Top[x] != Top[y])
{
if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
tmp = Segment_tree::ask_maxn(1, id[Top[x]], id[x]); ans = Max(ans, tmp);
x = fa[Top[x]];
}
if (dep[x] > dep[y]) std::swap(x, y);
if (x != y) {tmp = Segment_tree::ask_maxn(1, id[x] + 1, id[y]); ans = Max(ans, tmp);}
printf("%d\n", (ans == -0x7f7f7f7f) ? 0 : ans);
}
if (str == "MIN")
{
int x = read() + 1, y = read() + 1, ans = 0x7f7f7f7f, tmp = 0x7f7f7f7f;
while (Top[x] != Top[y])
{
if (dep[Top[x]] < dep[Top[y]]) std::swap(x, y);
tmp = Segment_tree::ask_minn(1, id[Top[x]], id[x]); ans = Min(ans, tmp);
x = fa[Top[x]];
}
if (dep[x] > dep[y]) std::swap(x, y);
if (x != y) {tmp = Segment_tree::ask_minn(1, id[x] + 1, id[y]); ans = Min(ans, tmp);}
printf("%d\n", (ans == 0x7f7f7f7f) ? 0 : ans);
}
}
return 0;
}
3. 總結
樹剖的題目難點還是在維護重鏈資訊上,樹剖本身不是特別難,在面對樹剖題目的時候我們可以假設問題是一個序列問題,在確定怎麼維護之後套上樹剖即可。