【結論】【樹(LCA)】NKOJ3815 樹上的詢問
阿新 • • 發佈:2019-01-02
NKOJ3815 樹上的詢問
時間限制 : - MS 空間限制 : 265536 KB
評測說明 : 1000ms
問題描述
現有一棵 n 個節點的樹,樹上每條邊的長度均為 1。給出 m 個詢問,每次詢問兩個節 點 a,b,求樹上到 a,b兩個點距離相同的節點數量。
輸入格式
第一個整數 n,表示樹有 n 個點。
接下來 n-1 行每行兩整數 a,b,表示從 a 到 b 有一條邊。
接下來一行一個整數 m,表示有 m 個詢問。
接下來 m 行每行兩整數 x,y,詢問到 x 和 y 距離相同的點的數量。
輸出格式
共 m 行,每行一個整數表示詢問的答案。
樣例輸入
7
1 2
1 3
2 4
2 5
3 6
3 7
3
1 2
4 5
2 3
樣例輸出
0
5
1
提示
【資料規模】
30%的資料,滿足 n≤50,m≤50 對於
60%的資料,滿足 n≤1000,m≤1000 對於
100%的資料,滿足 n≤100000,m≤100000
觀察發現:
#include<cstdio>
#include<iostream>
#include<cmath>
#include<cstdlib>
using namespace std;
const int need=100003;
const double lg2=log(2);
int n,m;
//.............................................................
inline void in_(int &d)
{
char t=getchar();
while(t<'0'||t>'9') t=getchar();
for(d=0;!(t<'0'||t>'9');t=getchar()) d=(d<<1)+(d<<3)+t-'0';
}
//.............................................................
int fi[need],la[need<<1],en[need<<1];
int tot=0;
void add(int a,int b)
{
tot++;
la[tot]=fi[a];
en[tot]=b;
fi[a]=tot;
}
//.............................................................
int dep[need];
int d=0;
bool vis[need];
int fa[need][22],son[need];
void dfs(int s)
{
vis[s]=true;
int k=ceil(log(dep[s])/lg2);
d=max(d,k);
for(int i=1;i<=k;i++) fa[s][i]=fa[fa[s][i-1]][i-1];
for(int t=fi[s],y;t;t=la[t])
{
y=en[t];
if(vis[y]) continue;
dep[y]=dep[s]+1;
fa[y][0]=s;
dfs(y);
son[s]+=son[y]+1;
}
}
int go_up(int v,int p)
{
for(int i=0;i<=d;i++) if((p>>i)&1) v=fa[v][i];
return v;
}
int lca(int v,int u)
{
if(dep[u]>dep[v]) swap(u,v);
v=go_up(v,dep[v]-dep[u]);
if(u==v) return v;
for(int i=d;i>=0;i--) if(fa[u][i]!=fa[v][i]) v=fa[v][i],u=fa[u][i];
return fa[u][0];
}
int getdis(int a,int b)
{
int c=lca(a,b);
return dep[a]+dep[b]-dep[c]-dep[c];
}
int getdis(int a,int b,int c)
{
return dep[a]+dep[b]-dep[c]-dep[c];
}
//.............................................................
int main()
{
scanf("%d",&n);
for(int i=1,a,b;i<n;i++)
{
in_(a),in_(b);//scanf("%d%d",&a,&b);
add(a,b),add(b,a);
}
dfs(1);
scanf("%d",&m);
for(int i=1,a,b,c,a1,ab,ac,bc,ans;i<=m;i++)
{
in_(a),in_(b);//scanf("%d%d",&a,&b);
c=lca(a,b);
ab=getdis(a,b,c);
if(ab&1) ans=0;
else
{
if(a==b) ans=n;
else
{
ac=getdis(a,c);
bc=getdis(b,c);
if(ac==bc)
{
a=go_up(a,dep[a]-dep[c]-1);
b=go_up(b,dep[b]-dep[c]-1);
ans=n-(son[a]+1)-(son[b]+1);
}
else
{
if(dep[b]>dep[a]) swap(a,b);
a=go_up(a,ab/2-1);
a1=go_up(a,1);
ans=son[a1]+1-(son[a]+1);
}
}
}
printf("%d\n",ans);
}
}