1. 程式人生 > 其它 >Snow的追尋(線段樹)(LCA)

Snow的追尋(線段樹)(LCA)

給你一棵樹,每次規定兩個子樹不能到,問你樹上的最長路徑長度。

Snow的追尋

題目大意

給你一棵樹,每次規定兩個子樹不能到,問你樹上的最長路徑長度。

思路

看到有關子樹,考慮用 dfs 序來搞。
而且一般這種子樹的操作會用到線段樹?
考慮用線段樹維護,維護 \(l\sim r\) 區間的點能形成的最長路徑。

這樣子的話,我們可以把題目要求不能有兩個子樹裡面的點得到剩下的點,用線段樹拿出那三段,然後再用線段樹的合併方法合併起來,最後得到的值就是答案。
然後考慮如何合併。
考慮線段樹維護這一條路徑的長度,以及兩段的兩個點。

那你合併的時候,就兩條路徑四個點,不難想到新的路徑的兩個端點一定是這四個之中的兩個。
那我們可以直接暴力看沒兩個點的匹配情況(求兩點之間路徑長度用 LCA 求),然後選最大的那個。

然後就可以啦。

程式碼

#include<cstdio>
#include<iostream>
#include<algorithm>

using namespace std;

const int N = 100005;
struct node {
	int to, nxt;
}e[N << 2];
int n, q, x, y, le[N], KK, up[N], ans, tmp;
int fa[N][21], deg[N], dfn[N], ed[N], dy[N];

void add(int x, int y) {
	e[++KK] = (node){y, le[x]}; le[x] = KK;
	e[++KK] = (node){x, le[y]}; le[y] = KK;
}

void dfs(int now, int father) {
	deg[now] = deg[father] + 1;
	fa[now][0] = father;
	dfn[now] = ++tmp;
	dy[tmp] = now;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			dfs(e[i].to, now);
		}
	ed[now] = tmp;
}

int LCA(int x, int y) {//LCA 求路徑長度
	if (deg[x] < deg[y]) swap(x, y);
	for (int i = 20; i >= 0; i--)
		if (deg[fa[x][i]] >= deg[y]) x = fa[x][i];
	if (x == y) return x;
	for (int i = 20; i >= 0; i--)
		if (fa[x][i] != fa[y][i])
			x = fa[x][i], y = fa[y][i];
	return fa[x][0];
}

int get_dis(int x, int y) {
	if (!x || !y) return 0;
	int z = LCA(x, y);
	return deg[x] + deg[y] - 2 * deg[z];
}

struct XDtree {//線段樹
	struct node {
		int val, fir, sec;
	}a[N << 2], ans;
	
	void merge(node &x, node y, node z) {
		int a, b, c, d, e;
		a = get_dis(y.fir, z.fir);
		b = get_dis(y.fir, z.sec);
		c = get_dis(y.sec, z.fir);
		d = get_dis(y.sec, z.sec);
		e = max(max(a, b), max(c, d));
		x.val = e;//四個點裡面任選兩個匹配得到最長路徑
		if (e == a) x.fir = y.fir, x.sec = z.fir;
		if (e == b) x.fir = y.fir, x.sec = z.sec;
		if (e == c) x.fir = y.sec, x.sec = z.fir;
		if (e == d) x.fir = y.sec, x.sec = z.sec;
		if (x.val < y.val) x.val = y.val, x.fir = y.fir, x.sec = y.sec;
		if (x.val < z.val) x.val = z.val, x.fir = z.fir, x.sec = z.sec;
		if (!x.val) x.fir = x.sec = 0;
	}
	
	void build(int now, int l, int r) {
		if (l == r) {
			a[now].fir = a[now].sec = dy[l];
			a[now].val = 0; return ;
		}
		int mid = (l + r) >> 1;
		build(now << 1, l, mid);
		build(now << 1 | 1, mid + 1, r);
		merge(a[now], a[now << 1], a[now << 1 | 1]); 
	}
	
	void find(int now, int l, int r, int L, int R) {
		if (L > R) return ;
		if (L <= l && r <= R) {
			merge(ans, ans, a[now]);
			return ;
		}
		int mid = (l + r) >> 1;
		if (L <= mid) find(now << 1, l, mid, L, R);
		if (mid < R) find(now << 1 | 1, mid + 1, r, L, R);
	}
}T;

int main() {
//	freopen("snow.in", "r", stdin);
//	freopen("snow.out", "w", stdout);
	
	scanf("%d %d", &n, &q);
	for (int i = 1; i < n; i++) {
		scanf("%d %d", &x, &y);
		add(x, y);
	}
	
	dfs(1, 0);
	for (int i = 1; i <= 20; i++)
		for (int j = 1; j <= n; j++)
			fa[j][i] = fa[fa[j][i - 1]][i - 1];
	T.build(1, 1, n);
	
	while (q--) {
		scanf("%d %d", &x, &y);
		if (dfn[y] < dfn[x]) swap(x, y);
		T.ans.val = T.ans.fir = T.ans.sec = 0;
		T.find(1, 1, n, 1, dfn[x] - 1);//分成三段合併入答案
		T.find(1, 1, n, ed[x] + 1, dfn[y] - 1);
		if (ed[y] >= ed[x]) T.find(1, 1, n, ed[y] + 1, n);
			else T.find(1, 1, n, ed[x] + 1, n);
		printf("%d\n", T.ans.val);
	}
	
	fclose(stdin);
	fclose(stdout);
	
	return 0;
}