SP10707 COT2 - Count on a tree II
阿新 • • 發佈:2022-04-07
\(\text{Solution}\)
統計樹上 \(x\) 到 \(y\) 路徑不同數的種類數
可以樹上莫隊
離線的樹上莫隊就是把樹用尤拉序拍下來,然後和序列上的莫隊一樣即可
\(\text{Code}\)
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <utility>

// SPOJ COT2 "Count on a tree II".
//
// For every query (x, y) print the number of distinct node weights on the
// tree path between x and y.
//
// Approach: offline Mo's algorithm on an Euler tour. Every node is written
// into the tour twice (on entry and on exit). A path query becomes a
// contiguous tour range in which nodes on the path occur exactly once and
// all other nodes occur twice; toggling a node on each occurrence therefore
// leaves exactly the path nodes active. When the LCA is not an endpoint it
// lies outside that range and is toggled in separately for the answer.
//
// Input  (stdin): n m, then n weights, then n-1 edges, then m queries.
// Output (stdout): one line per query, in input order.

// Capacity of node-indexed and query-indexed arrays (1-based, so n, m < N).
const int N = 100005;
// Capacity of arrays indexed by Euler-tour position: each node appears twice.
const int TOUR = 2 * N;

int n, m;            // number of nodes, number of queries
int a[N];            // a[v]: weight of node v; replaced by its rank after compression
int b[N];            // sorted, de-duplicated weights used for compression
int h[N], tot;       // h[v]: head of v's adjacency list; tot: edges stored so far
int dfc;             // Euler-tour clock (last position written)
int rev[TOUR];       // rev[p]: node written at tour position p
int st[N], ed[N];    // st[v] / ed[v]: entry / exit position of node v in the tour
int ans[N];          // ans[i]: answer of the i-th query in input order
int fa[N], dep[N], siz[N], son[N], top[N];  // parent, depth, subtree size, heavy child, chain head
int bl[TOUR];        // bl[p]: Mo block id of tour position p
int Ans;             // number of distinct weights among currently active nodes
int buc[N];          // buc[w]: number of active nodes with compressed weight w
int used[N];         // used[v]: 1 if node v is currently active, else 0

// Forward-star adjacency list; index 0 is the list terminator.
struct edge {
    int to, nxt;
} e[N * 2];

// Stores the directed edge x -> y at the front of x's adjacency list.
inline void add(int x, int y) {
    e[++tot] = edge{y, h[x]};
    h[x] = tot;
}

// One offline query: tour range [l, r], original index id, and z = the LCA
// that must be toggled in separately (0 when no extra node is needed).
struct node {
    int l, r, id, z;
} Q[N];

// Mo ordering: by block of the left endpoint, then by right endpoint,
// ascending in odd-numbered blocks and descending in even-numbered ones
// (the usual odd-even trick to shorten right-pointer travel).
inline bool cmp(const node& lhs, const node& rhs) {
    if (bl[lhs.l] != bl[rhs.l]) return lhs.l < rhs.l;
    return (bl[lhs.l] & 1) ? lhs.r < rhs.r : lhs.r > rhs.r;
}

// First DFS: builds the Euler tour (st, ed, rev) and records parent, depth,
// subtree size and heavy child of every node reachable from x.
void dfs1(int x, int f) {
    st[x] = ++dfc;
    rev[dfc] = x;
    fa[x] = f;
    dep[x] = dep[f] + 1;
    siz[x] = 1;
    for (int i = h[x]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == f) continue;
        dfs1(v, x);
        siz[x] += siz[v];
        // son[x] starts at 0 and siz[0] is 0, so the first child always wins.
        if (siz[v] > siz[son[x]]) son[x] = v;
    }
    ed[x] = ++dfc;
    rev[dfc] = x;
}

// Second DFS: assigns chain heads. The heavy child continues the current
// chain (head t); every other child starts a new chain headed by itself.
void dfs2(int x, int t) {
    top[x] = t;
    if (son[x]) dfs2(son[x], t);
    for (int i = h[x]; i; i = e[i].nxt) {
        const int v = e[i].to;
        if (v == fa[x] || v == son[x]) continue;
        dfs2(v, v);
    }
}

// Lowest common ancestor of x and y via chain jumping: repeatedly lift the
// node whose chain head is deeper until both nodes share a chain.
inline int LCA(int x, int y) {
    while (top[x] != top[y]) {
        if (dep[top[x]] > dep[top[y]]) {
            x = fa[top[x]];
        } else {
            y = fa[top[y]];
        }
    }
    return dep[x] < dep[y] ? x : y;
}

// Deactivates node x: drops one occurrence of its weight.
inline void Del(int x) {
    --buc[a[x]];
    if (!buc[a[x]]) --Ans;
}

// Activates node x: adds one occurrence of its weight.
inline void Add(int x) {
    if (!buc[a[x]]) ++Ans;
    ++buc[a[x]];
}

// Toggles node x between active and inactive.
inline void update(int x) {
    if (used[x]) {
        Del(x);
    } else {
        Add(x);
    }
    used[x] ^= 1;
}

// Reads the m queries and converts each into a tour range stored in Q[1..m].
// If a query cannot be read or names a node outside [1, n], m is truncated
// to the number of queries read successfully so far.
void GetQ() {
    for (int i = 1; i <= m; i++) {
        int x = 0, y = 0;
        if (std::scanf("%d%d", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n) {
            m = i - 1;
            return;
        }
        if (st[x] > st[y]) std::swap(x, y);
        const int lca = LCA(x, y);
        if (lca == x) {
            // x is an ancestor of y: the range from entering x to entering y
            // contains every path node exactly once.
            Q[i] = node{st[x], st[y], i, 0};
        } else {
            // Otherwise start at x's exit; the LCA is not inside this range
            // and is remembered in z.
            Q[i] = node{ed[x], st[y], i, lca};
        }
    }
}

// Reads and sorts the queries, then answers them by sliding the window
// [l, r] over the Euler tour.
void solve() {
    GetQ();
    std::sort(Q + 1, Q + m + 1, cmp);
    int l = 1, r = 0;
    for (int i = 1; i <= m; i++) {
        while (l < Q[i].l) update(rev[l++]);
        while (l > Q[i].l) update(rev[--l]);
        while (r < Q[i].r) update(rev[++r]);
        while (r > Q[i].r) update(rev[r--]);
        // Toggle the LCA in for the answer, then back out so the window
        // state stays consistent for the next query.
        if (Q[i].z) update(Q[i].z);
        ans[Q[i].id] = Ans;
        if (Q[i].z) update(Q[i].z);
    }
}

// Returns 0 on success and 1 when the header, weights or edges are malformed
// or exceed the array capacity.
int main() {
    if (std::scanf("%d%d", &n, &m) != 2 || n < 1 || n >= N || m < 0 || m >= N) return 1;

    for (int i = 1; i <= n; i++) {
        if (std::scanf("%d", &a[i]) != 1) return 1;
        b[i] = a[i];
    }

    // Coordinate compression: map every weight to its 1-based rank so it can
    // index buc[].
    std::sort(b + 1, b + n + 1);
    const int len = static_cast<int>(std::unique(b + 1, b + n + 1) - b - 1);
    for (int i = 1; i <= n; i++) {
        a[i] = static_cast<int>(std::lower_bound(b + 1, b + len + 1, a[i]) - b);
    }

    for (int i = 1; i < n; i++) {
        int x = 0, y = 0;
        if (std::scanf("%d%d", &x, &y) != 2 || x < 1 || x > n || y < 1 || y > n) return 1;
        add(x, y);
        add(y, x);
    }

    dfs1(1, 0);
    dfs2(1, 1);

    const int block = static_cast<int>(std::sqrt(static_cast<double>(n))) + 1;
    for (int i = 1; i <= dfc; i++) bl[i] = i / block + 1;

    solve();
    for (int i = 1; i <= m; i++) std::printf("%d\n", ans[i]);
    return 0;
}