1. 程式人生 > 實用技巧 >LCA最近公共祖先演算法

LCA最近公共祖先演算法

LCA最近公共祖先

​ LCA是指在有根樹中,找出某兩個節點\(u\)\(v\)的最近公共祖先,即找到一個節點,同時是\(u\)\(v\)的公共祖先,並且深度儘可能大

模板題目連結https://www.luogu.com.cn/problem/P3379

樸素演算法

​ 比如對於下面這樣一個樹,求LCA的過程,大體如下,首先我們先求出標號為4的節點的所有父節點,然後在對標號為5的節點不斷向上求父節點,判斷是否在4的父節點中,如果是就求出了就求出了公共祖先,如果樹的深度很大,時間複雜度就是\(O(n + m)\)

倍增法

​ 我們可以考慮,如果兩個節點同事向上跳,直到相遇,相遇的點就是他們的LCA。但是如果樹的深度很大,就需要跳很久,時間複雜度就是\(O(n*m)\)

。採用倍增法來優化。

​ 首先記錄下每個節點的父節點和各個祖先節點,使用一個\(f[N][30]\)陣列用來表示節點\(x\)的第\(i + 1\)位祖先,也就是說\(x\)的父親節點是\(f[x][0]\),這樣我們在更新的時候可以得到一個遞推式,\(f[x][i] = f[f[i][i-1]][i-1]\),就可以預處理處每個節點的祖先

​ 在向上跳的時候,首先讓\(x\)\(y\)處於同一層,讓深度更深的向上跳,然後兩個在一起跳,知道兩個節點有了同一個父節點。當然我們可以再兩個節點處於同一層的時候,判斷是否匯合如果匯合,就返回

時間複雜度

\(O(mlog(n))\)

實現程式碼

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

const int N = 500010, M = N * 2;

int n, m, root;
int h[N], e[M], ne[M], idx;
int depth[N], fa[N][26];
int q[N];

void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

void bfs(int root) {
    memset(depth, 0x3f, sizeof depth);
    depth[0] = 0, depth[root] = 1;
    int hh = 0, tt = 0;
    q[0] = root;
    while (hh <= tt) {
        int t = q[hh++];
        for (int i = h[t]; ~i; i = ne[i]) {
            int j = e[i];
            if (depth[j] > depth[t] + 1) {
                depth[j] = depth[t] + 1;
                q[++tt] = j;
                fa[j][0] = t;
                for (int k = 1; k <= 25; k++)
                    fa[j][k] = fa[fa[j][k - 1]][k - 1];
            }
        }
    }
}

int lca(int a, int b) {
    if (depth[a] < depth[b]) swap(a, b);
    for (int k = 25; k >= 0; k--)
        if (depth[fa[a][k]] >= depth[b])
            a = fa[a][k];
    if (a == b) return a;
    for (int k = 25; k >= 0; k--)
        if (fa[a][k] != fa[b][k]) {
            a = fa[a][k];
            b = fa[b][k];
        }
    return fa[a][0];
}

int main() {
    scanf("%d%d%d", &n, &m, &root);

    memset(h, -1, sizeof h);

    for (int i = 0; i < n - 1; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    bfs(root);

    while (m--) {
        int a, b;
        scanf("%d%d", &a, &b);
        int p = lca(a, b);
        printf("%d\n", p);
    }

    return 0;
}

Tarjan

​ 可以看出倍增的做法是強制線上演算法,必須針對每一個問題去單獨執行lca。Tarjan是強制離線演算法,每次將結果計算好,然後直接查詢即可。

​ tarjan演算法的流程如下:

  1. 從根節點開始
  2. 遍歷該點\(u\)的所有子節點\(v\),並標記這些子節點\(v\)已經被訪問過了
  3. 如果\(u\)還有子節點,就重複步驟2
  4. 合併\(v\)\(u\)
  5. 尋找與當前點\(u\)有詢問關係的點\(v\)
  6. 如果\(v\)已經被訪問過了,則可以確定\(u\)\(v\)的最近公共祖先為\(v\)被合併到父親節點\(a\)

實現程式碼

#include <iostream>
#include <cstring>
#include <cstdio>
#include <algorithm>
#include <vector>

using namespace std;

typedef pair<int, int> PII;

const int N = 500010, M = 2 * N;

int n, m, root;
int h[N], e[M], ne[M], idx;

int p[N];
int res[M];
int st[N];
int dist[N];
// first 存查詢的另外一個點,second存查詢編號
vector <PII> query[N];

void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

int find(int x) {
    if (x != p[x]) p[x] = find(p[x]);
    return p[x];
}

void tarjan(int u) {
    st[u] = 1;
    for (int i = h[u]; ~i; i = ne[i]) {
        int j = e[i];
        if (!st[j]) {
            tarjan(j);
            p[j] = u;
        }
    }

    for (auto item : query[u]) {
        int y = item.first, id = item.second;
        if (st[y] == 2) {
            int anc = find(y);
            res[id] = anc;
        }
    }

    st[u] = 2;
}


int main() {
    scanf("%d%d%d", &n, &m, &root);

    memset(h, -1, sizeof h);
    for (int i = 0; i < n - 1; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        add(a, b);
        add(b, a);
    }

    for (int i = 0; i < m; i++) {
        int a, b;
        scanf("%d%d", &a, &b);
        if (a != b) {
            query[a].push_back({b, i});
            query[b].push_back({a, i});
        }
    }

    for (int i = 1; i <= n; i++) p[i] = i;


    tarjan(root);

    for (int i = 0; i < m; i++) printf("%d\n", res[i]);
    return 0;
}