1. 程式人生 > >【結論】【樹(LCA)】NKOJ3815 樹上的詢問

【結論】【樹(LCA)】NKOJ3815 樹上的詢問

NKOJ3815 樹上的詢問
時間限制 : - MS 空間限制 : 265536 KB
評測說明 : 1000ms

問題描述
現有一棵 n 個節點的樹,樹上每條邊的長度均為 1。給出 m 個詢問,每次詢問兩個節 點 a,b,求樹上到 a,b兩個點距離相同的節點數量。

輸入格式
第一個整數 n,表示樹有 n 個點。
接下來 n-1 行每行兩整數 a,b,表示從 a 到 b 有一條邊。
接下來一行一個整數 m,表示有 m 個詢問。
接下來 m 行每行兩整數 x,y,詢問到 x 和 y 距離相同的點的數量。

輸出格式
共 m 行,每行一個整數表示詢問的答案。

樣例輸入


7
1 2
1 3
2 4
2 5
3 6
3 7
3
1 2
4 5
2 3

樣例輸出
0
5
1

提示
【資料規模】
30%的資料,滿足 n≤50,m≤50 對於
60%的資料,滿足 n≤1000,m≤1000 對於
100%的資料,滿足 n≤100000,m≤100000

觀察發現:
這裡寫圖片描述

#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstdlib>
using namespace std;
const int need=100003;
const double
lg2=log(2); int n,m; //............................................................. inline void in_(int &d) { char t=getchar(); while(t<'0'||t>'9') t=getchar(); for(d=0;!(t<'0'||t>'9');t=getchar()) d=(d<<1)+(d<<3)+t-'0'; } //.............................................................
int fi[need],la[need<<1],en[need<<1]; int tot=0; void add(int a,int b) { tot++; la[tot]=fi[a]; en[tot]=b; fi[a]=tot; } //............................................................. int dep[need]; int d=0; bool vis[need]; int fa[need][22],son[need]; void dfs(int s) { vis[s]=true; int k=ceil(log(dep[s])/lg2); d=max(d,k); for(int i=1;i<=k;i++) fa[s][i]=fa[fa[s][i-1]][i-1]; for(int t=fi[s],y;t;t=la[t]) { y=en[t]; if(vis[y]) continue; dep[y]=dep[s]+1; fa[y][0]=s; dfs(y); son[s]+=son[y]+1; } } int go_up(int v,int p) { for(int i=0;i<=d;i++) if((p>>i)&1) v=fa[v][i]; return v; } int lca(int v,int u) { if(dep[u]>dep[v]) swap(u,v); v=go_up(v,dep[v]-dep[u]); if(u==v) return v; for(int i=d;i>=0;i--) if(fa[u][i]!=fa[v][i]) v=fa[v][i],u=fa[u][i]; return fa[u][0]; } int getdis(int a,int b) { int c=lca(a,b); return dep[a]+dep[b]-dep[c]-dep[c]; } int getdis(int a,int b,int c) { return dep[a]+dep[b]-dep[c]-dep[c]; } //............................................................. int main() { scanf("%d",&n); for(int i=1,a,b;i<n;i++) { in_(a),in_(b);//scanf("%d%d",&a,&b); add(a,b),add(b,a); } dfs(1); scanf("%d",&m); for(int i=1,a,b,c,a1,ab,ac,bc,ans;i<=m;i++) { in_(a),in_(b);//scanf("%d%d",&a,&b); c=lca(a,b); ab=getdis(a,b,c); if(ab&1) ans=0; else { if(a==b) ans=n; else { ac=getdis(a,c); bc=getdis(b,c); if(ac==bc) { a=go_up(a,dep[a]-dep[c]-1); b=go_up(b,dep[b]-dep[c]-1); ans=n-(son[a]+1)-(son[b]+1); } else { if(dep[b]>dep[a]) swap(a,b); a=go_up(a,ab/2-1); a1=go_up(a,1); ans=son[a1]+1-(son[a]+1); } } } printf("%d\n",ans); } }