1. 程式人生 > 其它 >SP10707 COT2 - Count on a tree II

SP10707 COT2 - Count on a tree II

\(\text{Solution}\)

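統計樹上 \(x\) 到 \(y\) 路徑上不同權值的種類數。
可以用樹上莫隊解決。
離線的樹上莫隊就是把樹用尤拉序（括號序：每個點進入、離開時各記一次）拍成一個長度為 \(2n\) 的序列，然後和序列上的莫隊一樣處理即可。設 \(st_x < st_y\)：若 \(x\) 是 \(y\) 的祖先，則路徑上的點恰為區間 \([st_x, st_y]\) 中只出現一次的點；否則取區間 \([ed_x, st_y]\)，並額外加入 \(\operatorname{lca}(x,y)\)（它不在該區間內）。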

\(\text{Code}\)

#include <cstdio>
#include <algorithm>
#include <cmath>
#define RE register
#define IN inline
using namespace std;

const int N = 1e5 + 5; // capacity for vertices and queries (indices start at 1)
// Input data: n vertices, m queries; a = vertex weights (compressed to ranks in
// main), b = sorted copy of the weights used for that compression;
// h / tot = adjacency-list heads and number of stored edges.
int n, m, a[N], b[N], h[N], tot;
// dfc = number of positions written to the bracket sequence so far;
// st / ed = first (entry) and second (exit) position of each vertex;
// ans = answer for each query, indexed by input order.
int dfc, st[N], ed[N], ans[N];
// Heavy-light decomposition: parent, depth, subtree size, heavy child, chain head.
int fa[N], dep[N], siz[N], son[N], top[N];
// Mo's window state: Ans = number of distinct weights currently active,
// buc = active count per weight, used = toggle parity per vertex.
int Ans, buc[N], used[N];
// rev (sequence position -> vertex) and bl (sequence position -> Mo block) are
// indexed by bracket-sequence positions 1..2n, because every vertex is written
// twice. They therefore need twice the per-vertex capacity.
int rev[N * 2], bl[N * 2];
struct edge{int to, nxt;}e[N * 2];
IN void add(int x, int y){e[++tot] = edge{y, h[x]}, h[x] = tot;}
struct node{int l, r, id, z;}Q[N];
IN bool cmp(node a, node b){return ((bl[a.l] ^ bl[b.l]) ? a.l < b.l : ((bl[a.l] & 1) ? a.r < b.r : a.r > b.r));}

void dfs1(int x, int f)
{
	st[x] = ++dfc, rev[dfc] = x, fa[x] = f, dep[x] = dep[f] + 1, siz[x] = 1;
	for(RE int i = h[x]; i; i = e[i].nxt)
	{
		int v = e[i].to;
		if (v == f) continue;
		dfs1(v, x), siz[x] += siz[v];
		if (siz[v] > siz[son[x]]) son[x] = v;
	}
	ed[x] = ++dfc, rev[dfc] = x;
}
void dfs2(int x, int t)
{
	top[x] = t;
	if (son[x]) dfs2(son[x], t);
	for(RE int i = h[x]; i; i = e[i].nxt)
	{
		int v = e[i].to;
		if (v == fa[x] || v == son[x]) continue;
		dfs2(v, v);
	}
}
IN int LCA(int x, int y)
{
	while (top[x] ^ top[y])
	{
		if (dep[top[x]] > dep[top[y]]) x = fa[top[x]];
		else y = fa[top[y]];
	}
	return (dep[x] < dep[y] ? x : y);
}

IN void Del(int x){--buc[a[x]]; if (!buc[a[x]]) --Ans;}
IN void Add(int x){if (!buc[a[x]]) ++Ans; ++buc[a[x]];}
IN void update(int x){(used[x] ? Del(x) : Add(x)), used[x] ^= 1;}
void GetQ()
{
	for(RE int i = 1, x, y; i <= m; i++)
	{
		scanf("%d%d", &x, &y);
		if (st[x] > st[y]) swap(x, y);
		int lca = LCA(x, y);
		if (lca == x) Q[i] = node{st[x], st[y], i};
		else Q[i] = node{ed[x], st[y], i, lca};
	}
}
void solve()
{
	GetQ(), sort(Q + 1, Q + m + 1, cmp);
	int l = 1, r = 0;
	for(RE int i = 1; i <= m; i++)
	{
		while (l < Q[i].l) update(rev[l++]);
		while (l > Q[i].l) update(rev[--l]);
		while (r < Q[i].r) update(rev[++r]);
		while (r > Q[i].r) update(rev[r--]);
		if (Q[i].z) update(Q[i].z);
		ans[Q[i].id] = Ans;
		if (Q[i].z) update(Q[i].z);
	}
}

// Reads the tree and queries from stdin, answers every query, and prints one
// answer per line.
int main()
{
	scanf("%d%d", &n, &m);
	// Coordinate-compress vertex weights to ranks 1..len so they can index buc.
	for (int v = 1; v <= n; ++v)
	{
		scanf("%d", &a[v]);
		b[v] = a[v];
	}
	std::sort(b + 1, b + n + 1);
	const int len = static_cast<int>(std::unique(b + 1, b + n + 1) - (b + 1));
	for (int v = 1; v <= n; ++v)
		a[v] = static_cast<int>(std::lower_bound(b + 1, b + len + 1, a[v]) - b);
	// n - 1 undirected edges, each stored in both directions.
	for (int k = 1; k < n; ++k)
	{
		int x, y;
		scanf("%d%d", &x, &y);
		add(x, y);
		add(y, x);
	}
	dfs1(1, 0);
	dfs2(1, 1);
	// Assign Mo's blocks to bracket-sequence positions 1..dfc.
	const int block = static_cast<int>(std::sqrt(static_cast<double>(n))) + 1;
	for (int pos = 1; pos <= dfc; ++pos) bl[pos] = pos / block + 1;
	solve();
	for (int q = 1; q <= m; ++q) printf("%d\n", ans[q]);
}