1. 程式人生 > 實用技巧 >洛谷 P4211 [LNOI2014]LCA

洛谷 P4211 [LNOI2014]LCA

題意

給出一個 \(n\) 個節點的有根樹(編號為 \(0\)\(n-1\),根節點為 \(0\))。
一個點的深度定義為這個節點到根的距離 \(+1\),記為 \(dep[i]\)
\(LCA(i,j)\) 表示 \(i\)\(j\) 的最近公共祖先。
\(q\) 次詢問,每次詢問給出 \(l\ r\ z\),求

\[\sum\limits_{i=l}^r dep[\text{LCA}(i,z)] \]

思路

樹鏈剖分

不喜歡從 \(0\) 開始編號,因此這裡從 \(1\) 開始編號,輸入的點都要 \(+1\)

先單獨考慮一次查詢,我們發現 \(dep[i]\) 其實就是 \(i\)

點到根節點的點數(包括自己),所以單獨一次查詢時所做的操作就是將區間 \([l,r]\) 中的點到根節點的點上的權值加一,然後查詢 \(z\) 到根節點的權值和,可能有些難理解,下面畫圖理解一下。

擴充套件到多個詢問的情況,這個時候每次執行上面的操作就顯得很不對。再看一下原來的式子:

\[\sum\limits_{i=l}^r dep[\text{LCA}(i,z)] \]

稍加思考我們就可以發現,這個式子可以分成兩個部分的差,也就是:

\(\sum\limits_{i=1}^r dep[\text{LCA}(i,z)]-\sum\limits_{i=1}^{l-1} dep[\text{LCA}(i,z)]\)

因為詢問並不是強制線上的,所以我們可以將詢問離線,把每一個詢問分為 \(1\sim l-1\)\(1\sim r\) 兩個小詢問, \(1\sim l-1\) 的詢問的 \(tag\) 標記為 \(0\),表示我們要減去這個值。 \(1\sim r\) 的詢問的 \(tag\) 標記為 \(1\),表示要加上這個值。

將詢問按照端點進行排序,從小到大進行處理,如果當前節點小於詢問的右端點,就一直加 \(now\) 到根節點的點權,不斷增大 \(now\),直到 \(now=\) 詢問端點,計算出當前詢問 \(z\) 到根節點的值,然後用 \(tag\) 判斷是加還是減。

上述操作可以用樹剖套線段樹實現,如果學過的話應該很容易理解。

時間複雜度 \(O(q\log^2 n)\)

程式碼

/*
  Name: P4211 [LNOI2014]LCA
  Author: Loceaner
  Date: 07/09/20 08:58
  Description: 
  Debug: dfn[x]=++tot寫成dfn[tp]=++tot
         線段樹中update忘記return 
*/
#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
using namespace std;

const int A = 2e5 + 11;
const int B = 1e6 + 11;
const int mod = 201314;
const int inf = 0x3f3f3f3f;

inline int read() {
  char c = getchar();
  int x = 0, f = 1;
  for ( ; !isdigit(c); c = getchar()) if (c == '-') f = -1;
  for ( ; isdigit(c); c = getchar()) x = x * 10 + (c ^ 48);
  return x * f;
}

struct node { int to, nxt; } e[A];
int n, q, qcnt, cnt, head[A], ans[A];
struct ques { int r, z, num, flag; } a[A];
int tot, fa[A], siz[A], son[A], dfn[A], top[A], dep[A]; //樹剖 

inline void add(int from, int to) {
  e[++cnt].to = to;
  e[cnt].nxt = head[from];
  head[from] = cnt;
}

bool cmp(ques a, ques b) {
  return a.r < b.r;
}

namespace Seg {
  #define lson rt << 1
  #define rson rt << 1 | 1
  struct tree { int l, r, sum, lazy; } t[A << 2];
  inline void pushup(int rt) {
    t[rt].sum = (t[lson].sum + t[rson].sum) % mod;
  }
  inline void pushdown(int rt) {
    t[lson].sum += (t[lson].r - t[lson].l + 1) * t[rt].lazy, t[lson].sum %= mod;
    t[rson].sum += (t[rson].r - t[rson].l + 1) * t[rt].lazy, t[rson].sum %= mod;
    t[lson].lazy += t[rt].lazy, t[lson].lazy %= mod;
    t[rson].lazy += t[rt].lazy, t[rson].lazy %= mod;
    t[rt].lazy = 0;
  }
  void build(int rt, int l, int r) {
    t[rt].l = l, t[rt].r = r;
    if (l == r) return;
    int mid = (l + r) >> 1;
    build(lson, l, mid), build(rson, mid + 1, r);
    pushup(rt);
  }
  void update(int rt, int l, int r, int x) {
    if (l <= t[rt].l && t[rt].r <= r) {
      t[rt].sum += (t[rt].r - t[rt].l + 1) * x, t[rt].sum %= mod;
      t[rt].lazy += x, t[rt].lazy %= mod;
      return;
    }
    if (t[rt].lazy) pushdown(rt);
    int mid = (t[rt].l + t[rt].r) >> 1;
    if (l <= mid) update(lson, l, r, x);
    if (r > mid) update(rson, l, r, x);
    pushup(rt);
  }
  int query(int rt, int l, int r) {
    if (l <= t[rt].l && t[rt].r <= r) return t[rt].sum;
    if (t[rt].lazy) pushdown(rt);
    int mid = (t[rt].l + t[rt].r) >> 1, ans = 0;
    if (l <= mid) ans += query(lson, l, r);
    if (r > mid) ans += query(rson, l, r);
    return ans % mod;
  }
}

void prepare(int x, int f) {
  fa[x] = f, siz[x] = 1, dep[x] = dep[f] + 1;
  for (int i = head[x]; i; i = e[i].nxt) {
    int to = e[i].to;
    if (to == f) continue;
    prepare(to, x), siz[x] += siz[to];
    if (siz[to] > siz[son[x]]) son[x] = to;
  }
}

void dfs(int x, int tp) {
  top[x] = tp, dfn[x] = ++tot;
  if (son[x]) dfs(son[x], tp);
  for (int i = head[x]; i; i = e[i].nxt) {
    int to = e[i].to;
    if (to == fa[x] || to == son[x]) continue;
    dfs(to, to);
  }
}

inline void add_val(int x, int y) {
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) swap(x, y);
    Seg::update(1, dfn[top[x]], dfn[x], 1);
    x = fa[top[x]];
  }
  if (dep[x] > dep[y]) swap(x, y);
  Seg::update(1, dfn[x], dfn[y], 1);
  return;
} 

inline int ask_val(int x, int y) {
  int ans = 0;
  while (top[x] != top[y]) {
    if (dep[top[x]] < dep[top[y]]) swap(x, y);
    ans += Seg::query(1, dfn[top[x]], dfn[x]);
    x = fa[top[x]];
  }
  if (dep[x] > dep[y]) swap(x, y);
  ans += Seg::query(1, dfn[x], dfn[y]);
  return ans;
}

int main() {
  n = read(), q = read();
  for (int i = 2; i <= n; i++) {
    fa[i] = read() + 1;
    add(fa[i], i), add(i, fa[i]);
  }
  prepare(1, 0), dfs(1, 1), Seg::build(1, 1, n);
  for (int i = 1; i <= q; i++) {
    int l = read() + 1, r = read() + 1, z = read() + 1;
    a[++qcnt] = (ques) {l - 1, z, i, 0};
    a[++qcnt] = (ques) {r, z, i, 1};
  }
  sort(a + 1, a + 1 + qcnt, cmp);
  int now = 1;
  for (int i = 1; i <= qcnt; i++) {
    while (now <= a[i].r) add_val(1, now++);
    if (a[i].flag == 1) ans[a[i].num] += ask_val(1, a[i].z);
    else ans[a[i].num] -= ask_val(1, a[i].z);
    ans[a[i].num] += mod, ans[a[i].num] %= mod;
  }
  for (int i = 1; i <= q; i++) cout << ans[i] << '\n';
  return 0;
}