洛谷 P4211 [LNOI2014]LCA
題意
給出一個 \(n\) 個節點的有根樹(編號為 \(0\) 到 \(n-1\),根節點為 \(0\))。
一個點的深度定義為這個節點到根的距離 \(+1\),記為 \(dep[i]\)
用 \(LCA(i,j)\) 表示 \(i\) 與 \(j\) 的最近公共祖先。
有 \(q\) 次詢問,每次詢問給出 \(l\ r\ z\),求\[\sum\limits_{i=l}^r dep[\text{LCA}(i,z)] \]
思路
樹鏈剖分
不喜歡從 \(0\) 開始編號,因此這裡從 \(1\) 開始編號,輸入的點都要 \(+1\)。
先單獨考慮一次查詢,我們發現 \(dep[i]\) 其實就是 \(i\)
擴充套件到多個詢問的情況,這個時候每次執行上面的操作就顯得很不對。再看一下原來的式子:
\[\sum\limits_{i=l}^r dep[\text{LCA}(i,z)] \]
稍加思考我們就可以發現,這個式子可以分成兩個部分的差,也就是:
\(\sum\limits_{i=1}^r dep[\text{LCA}(i,z)]-\sum\limits_{i=1}^{l-1} dep[\text{LCA}(i,z)]\)
因為詢問並不是強制線上的,所以我們可以將詢問離線,把每一個詢問分為 \(1\sim l-1\) 和 \(1\sim r\) 兩個小詢問, \(1\sim l-1\) 的詢問的 \(tag\) 標記為 \(0\),表示我們要減去這個值。 \(1\sim r\) 的詢問的 \(tag\) 標記為 \(1\),表示要加上這個值。
將詢問按照端點進行排序,從小到大進行處理,如果當前節點小於詢問的右端點,就一直加 \(now\) 到根節點的點權,不斷增大 \(now\),直到 \(now=\) 詢問端點,計算出當前詢問 \(z\) 到根節點的值,然後用 \(tag\) 判斷是加還是減。
上述操作可以用樹剖套線段樹實現,如果學過的話應該很容易理解。
時間複雜度 \(O(q\log^2 n)\)。
程式碼
/*
Name: P4211 [LNOI2014]LCA
Author: Loceaner
Date: 07/09/20 08:58
Description:
Debug: dfn[x]=++tot寫成dfn[tp]=++tot
線段樹中update忘記return
*/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;
const int A = 2e5 + 11;
const int B = 1e6 + 11;
const int mod = 201314;
const int inf = 0x3f3f3f3f;
inline int read() {
char c = getchar();
int x = 0, f = 1;
for ( ; !isdigit(c); c = getchar()) if (c == '-') f = -1;
for ( ; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
return x * f;
}
struct node { int to, nxt; } e[A];
int n, q, qcnt, cnt, head[A], ans[A];
struct ques { int r, z, num, flag; } a[A];
int tot, fa[A], siz[A], son[A], dfn[A], top[A], dep[A]; //樹剖
inline void add(int from, int to) {
e[++cnt].to = to;
e[cnt].nxt = head[from];
head[from] = cnt;
}
bool cmp(ques a, ques b) {
return a.r < b.r;
}
namespace Seg {
#define lson rt << 1
#define rson rt << 1 | 1
struct tree { int l, r, sum, lazy; } t[A << 2];
inline void pushup(int rt) {
t[rt].sum = (t[lson].sum + t[rson].sum) % mod;
}
inline void pushdown(int rt) {
t[lson].sum += (t[lson].r - t[lson].l + 1) * t[rt].lazy, t[lson].sum %= mod;
t[rson].sum += (t[rson].r - t[rson].l + 1) * t[rt].lazy, t[rson].sum %= mod;
t[lson].lazy += t[rt].lazy, t[lson].lazy %= mod;
t[rson].lazy += t[rt].lazy, t[rson].lazy %= mod;
t[rt].lazy = 0;
}
void build(int rt, int l, int r) {
t[rt].l = l, t[rt].r = r;
if (l == r) return;
int mid = (l + r) >> 1;
build(lson, l, mid), build(rson, mid + 1, r);
pushup(rt);
}
void update(int rt, int l, int r, int x) {
if (l <= t[rt].l && t[rt].r <= r) {
t[rt].sum += (t[rt].r - t[rt].l + 1) * x, t[rt].sum %= mod;
t[rt].lazy += x, t[rt].lazy %= mod;
return;
}
if (t[rt].lazy) pushdown(rt);
int mid = (t[rt].l + t[rt].r) >> 1;
if (l <= mid) update(lson, l, r, x);
if (r > mid) update(rson, l, r, x);
pushup(rt);
}
int query(int rt, int l, int r) {
if (l <= t[rt].l && t[rt].r <= r) return t[rt].sum;
if (t[rt].lazy) pushdown(rt);
int mid = (t[rt].l + t[rt].r) >> 1, ans = 0;
if (l <= mid) ans += query(lson, l, r);
if (r > mid) ans += query(rson, l, r);
return ans % mod;
}
}
void prepare(int x, int f) {
fa[x] = f, siz[x] = 1, dep[x] = dep[f] + 1;
for (int i = head[x]; i; i = e[i].nxt) {
int to = e[i].to;
if (to == f) continue;
prepare(to, x), siz[x] += siz[to];
if (siz[to] > siz[son[x]]) son[x] = to;
}
}
void dfs(int x, int tp) {
top[x] = tp, dfn[x] = ++tot;
if (son[x]) dfs(son[x], tp);
for (int i = head[x]; i; i = e[i].nxt) {
int to = e[i].to;
if (to == fa[x] || to == son[x]) continue;
dfs(to, to);
}
}
inline void add_val(int x, int y) {
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
Seg::update(1, dfn[top[x]], dfn[x], 1);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
Seg::update(1, dfn[x], dfn[y], 1);
return;
}
inline int ask_val(int x, int y) {
int ans = 0;
while (top[x] != top[y]) {
if (dep[top[x]] < dep[top[y]]) swap(x, y);
ans += Seg::query(1, dfn[top[x]], dfn[x]);
x = fa[top[x]];
}
if (dep[x] > dep[y]) swap(x, y);
ans += Seg::query(1, dfn[x], dfn[y]);
return ans;
}
int main() {
n = read(), q = read();
for (int i = 2; i <= n; i++) {
fa[i] = read() + 1;
add(fa[i], i), add(i, fa[i]);
}
prepare(1, 0), dfs(1, 1), Seg::build(1, 1, n);
for (int i = 1; i <= q; i++) {
int l = read() + 1, r = read() + 1, z = read() + 1;
a[++qcnt] = (ques) {l - 1, z, i, 0};
a[++qcnt] = (ques) {r, z, i, 1};
}
sort(a + 1, a + 1 + qcnt, cmp);
int now = 1;
for (int i = 1; i <= qcnt; i++) {
while (now <= a[i].r) add_val(1, now++);
if (a[i].flag == 1) ans[a[i].num] += ask_val(1, a[i].z);
else ans[a[i].num] -= ask_val(1, a[i].z);
ans[a[i].num] += mod, ans[a[i].num] %= mod;
}
for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
return 0;
}