1. 程式人生 > 實用技巧 >JZOJ 6904. 【2020.11.28提高組模擬】T3 樹上詢問(query)

JZOJ 6904. 【2020.11.28提高組模擬】T3 樹上詢問(query)

題目

你有一棵 \(n\) 節點的樹 ,回答 \(m\) 個詢問,每次詢問給你兩個整數 \(l,r\) ,問存在多少個整數 \(k\) 使得從 \(l\) 沿著 \(l \to r\) 的簡單路徑走 \(k\) 步恰好到達 \(k\)

分析

考慮離線後按鏈記貢獻
\(l\)\(lca(l,r)\) 這段鏈上,可以計入貢獻的點 \(x\) 滿足 \(dep[l]-x=dep[x]\),稱為一類貢獻
\(dep[x]+x=dep[l]\), 因為已知 \(dep[l]\),所以直接開桶計算
\(lca(l,r)\)\(r\) 這段鏈上,可以計入貢獻的點 \(x\) 滿足 \(dep[lca]+(x-dep[l]-dep[lca])=dep[x]\)

,稱為二類貢獻
\(dep[x]-x=2\times dep[lca]-dep[l]\),同樣可以直接開另一個桶計算
因為 \(dfs\) 下來時桶記錄的是根到當前點的資訊,所以算貢獻的時候要減去 \(lca\) 處的假貢獻
\(lca\) 也可能成為需要貢獻,所以算二類貢獻的時候減去 \(father_{lca}\) 處的貢獻
具體細節體現在程式碼

\(Code\)

#include<cstdio>
#include<vector>
using namespace std;

const int N = 3e5 + 5;
int n, m, dep[N], d[2][2*N], fa[N], da[N], vis[N], l[N], r[N], lca[N], ans[N];
vector<int> e[N];
struct node1{int x, id;};
vector<node1> q1[N];
struct node2{int cs, ty, f, id;};
vector<node2> q2[N];

int find(int x){return fa[x] == x ? x : fa[x] = find(fa[x]);}
void dfs(int x, int dad)
{
	da[x] = dad, dep[x] = dep[dad] + 1;
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs(e[x][i], x);
	}
}
void dfs1(int x, int dad)
{
	vis[x] = 1;
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs1(e[x][i], x), fa[e[x][i]] = x;
	}
	for(register int i = 0; i < q1[x].size(); i++)
		if (vis[q1[x][i].x]) lca[q1[x][i].id] = find(q1[x][i].x);
}
void dfs2(int x, int dad)
{
	++d[0][dep[x] + x], ++d[1][dep[x] - x + n];
	for(register int i = 0; i < q2[x].size(); i++)
		ans[q2[x][i].id] += q2[x][i].f * d[q2[x][i].ty][q2[x][i].cs];
	for(register int i = 0; i < e[x].size(); i++)
	{
		if (e[x][i] == dad) continue;
		dfs2(e[x][i], x);
	}
	--d[0][dep[x] + x], --d[1][dep[x] - x + n];
}

int main()
{
	freopen("query.in" , "r" , stdin);
	freopen("query.out" , "w" , stdout);
	scanf("%d%d" , &n , &m);
	int x , y;
	for(register int i = 1; i < n; i++)
	{
		scanf("%d%d" , &x , &y);
		e[x].push_back(y), e[y].push_back(x);
	}
	for(register int i = 1; i <= m; i++)
	{
		scanf("%d%d" , &l[i], &r[i]);
		q1[l[i]].push_back(node1{r[i], i});
		q1[r[i]].push_back(node1{l[i], i});
	}
	for(register int i = 1; i <= n; i++) fa[i] = i;
	dfs(1, 0), dfs1(1, 0);
	for(register int i = 1; i <= m; i++)
	{
		q2[l[i]].push_back(node2{dep[l[i]], 0, 1, i}); 
		q2[lca[i]].push_back(node2{dep[l[i]], 0, -1, i});
		q2[r[i]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, 1, i});
		if (lca[i] > 1) 
			q2[da[lca[i]]].push_back(node2{2*dep[lca[i]]-dep[l[i]]+n, 1, -1, i});
	}
	dfs2(1, 0);
	for(register int i = 1; i <= m; i++) printf("%d\n" , ans[i]);
}