Algorithm---LCA(倍增演算法)
阿新 • • 發佈:2019-02-10
deep[i] 表示 i節點的深度, fa[i,j]表示 i 的 2^j (即2的j次方) 倍祖先,那麼fa[i , 0]即為節點i 的父親,然後就有一個遞推式子:
fa[i,j]= fa [ fa [i,j-1] , j-1 ]
可以這樣理解:
設tmp = fa [i, j - 1] ,tmp2 = fa [tmp, j - 1 ] ,即tmp 是i 的第2 ^ (j - 1) 倍祖先,tmp2 是tmp 的第2 ^ (j - 1) 倍祖先 , 所以tmp2 是i 的第 2 ^ (j - 1) + 2 ^ (j - 1) = 2^ j 倍祖先,注意:這裡的“倍”可不能理解為倍數的意思,而是距離節點 i有多遠的意思,節點i的第2
^ j 倍祖先表示的節點u滿足deep[ u ] - deep[ i ] = 2 ^ j。這樣子一個O(NlogN)的預處理求出每個節點的 2^k 的祖先
然後對於每一個詢問的點對a, b的最近公共祖先就是:
先判斷是否 d[x]< d[y] ,如果是的話就交換一下(保證 x 的深度大於 y 的深度), 然後把 x 調到與 y 同深度, 同深度以後再把a, b 同時往上調,調到有一個最小的 j 滿足fa [x,j] != fa [y,j] (x,y是在不斷更新的), 最後再把(x,y)往上調(x=p[x,0], y=p[y,0]) ,一個一個向上調直到x = y, 這時 x或y 就是他們的最近公共祖先。
Ps:如果還是不明白,就手動模擬一棵節點數為9的樹(如下圖所示),很快就會理解的。還有我不得不感嘆一句 :二進位制真的很神奇!!
#include<iostream> #include<cstring> #include<algorithm> #include<string> #include<cmath> #include<vector> #include<cstdio> #define mem(a , b) memset(a , b , sizeof(a)) using namespace std ; inline void RD(int &a) { a = 0 ; char t ; do { t = getchar() ; } while (t < '0' || t > '9') ; a = t - '0' ; while ((t = getchar()) >= '0' && t <= '9') { a = a * 10 + t - '0' ; } } inline void OT(int a) { if(a >= 10) { OT(a / 10) ; } putchar(a % 10 + '0') ; } const int MAXN = 10005 ; const int M = 30 ; vector<int> G[MAXN] ; bool vis[MAXN] ; int deep[MAXN] ; int fa[MAXN][M] ; int n ; int root ; void chu() { mem(vis , 0) ; mem(deep , 0) ; mem(fa , 0) ; int i ; for(i = 0 ; i <= n ; i ++) G[i].clear() ; } void dfs(int u) { vis[u] = true ; int i ; for(i = 0 ; i < G[u].size() ; i ++) { int v = G[u][i] ; if(!vis[v]) { deep[v] = deep[u] + 1 ; dfs(v) ; } } } void bz() // 倍增祖先 { int i , j ; for(j = 1 ; j < M ; j ++) { for(i = 1 ; i <= n ; i ++) { fa[i][j] = fa[ fa[i][j - 1] ][j - 1] ; } } } void swap(int &x , int &y) { int tmp = x ; x = y ; y = tmp ; } int LCA(int u , int v) { if(deep[u] < deep[v]) swap(u , v) ; int d = deep[u] - deep[v] ; int i ; for(i = 0 ; i < M ; i ++) { if( (1 << i) & d ) // 注意此處,動手模擬一下,就會明白的 { u = fa[u][i] ; } } if(u == v) return u ; for(i = M - 1 ; i >= 0 ; i --) { if(fa[u][i] != fa[v][i]) { u = fa[u][i] ; v = fa[v][i] ; } } u = fa[u][0] ; return u ; } void init() { scanf("%d" , &n) ; chu() ; int i ; for(i = 0 ; i < n - 1 ; i ++) { int a , b ; scanf("%d%d" , &a , &b) ; G[a].push_back(b) ; fa[b][0] = a ; if(fa[a][0] == 0) { root = a ; } } deep[root] = 1 ; dfs(root) ; bz() ; int u , v ; scanf("%d%d" , &u , &v) ; printf("%d\n", LCA(u , v)) ; } int main() { int T ; scanf("%d" , &T) ; while (T --) { init() ; } return 0 ; }