線段樹合併 csu1811 Tree Intersection
阿新 • • 發佈:2019-01-07
題意:對於每條邊,把這條邊刪了,樹分成了兩個集合,求這兩個集合中共同的顏色數量。
對於節點u,就看u節點的子樹中,有多少種顏色沒有到達這種顏色的上限,就是對u所對應的邊的答案
方法一:我們用線段樹合併來維護,程式碼寫起來比較麻煩。
之所以可以用線段樹合併,是因為有一個結論:如果初始的時候只有一個葉子的線段樹有n個,那麼最後合併成一個線段樹,總複雜度只有O(nlogn)
#include <map> #include <set> #include <cmath> #include <ctime> #include <stack> #include <queue> #include <cstdio> #include <cctype> #include <bitset> #include <string> #include <vector> #include <cstring> #include <iostream> #include <algorithm> #include <functional> #define fuck(x) cout<<"["<<x<<"]"; #define FIN freopen("input.txt","r",stdin); #define FOUT freopen("output.txt","w+",stdout); //#pragma comment(linker, "/STACK:102400000,102400000") using namespace std; typedef long long LL; typedef pair<int, int>PII; const int MX = 1e5 + 5; const int MD = 2e6 + 5; struct Edge { int v, nxt; } E[MX << 1]; int Head[MX], erear; void edge_init() { erear = 0; memset(Head, -1, sizeof(Head)); } void edge_add(int u, int v) { E[erear].v = v; E[erear].nxt = Head[u]; Head[u] = erear++; } struct Node { int l, r; int val, sum; } A[MD]; int n, sz, id[MX], C[MX], root[MX]; int ans[MX], cnt[MX]; void push_up(int rt) { A[rt].sum = A[A[rt].l].sum + A[A[rt].r].sum; } int build(int c, int l, int r) { int rt = ++sz; A[rt].l = A[rt].r = 0; A[rt].sum = 0; if(l == r) { A[rt].val = 1; A[rt].sum = (A[rt].val != cnt[l]); return rt; } int m = (l + r) >> 1; if(c <= m) A[rt].l = build(c, l, m); else A[rt].r = build(c, m + 1, r); push_up(rt); return rt; } void merge(int &rt1, int rt2, int l, int r) { if(!rt1 || !rt2) { if(!rt1) rt1 = rt2; return; } if(l == r) { A[rt1].val += A[rt2].val; A[rt1].sum = (A[rt1].val != cnt[l]); return; } int m = (l + r) >> 1; merge(A[rt1].l, A[rt2].l, l, m); merge(A[rt1].r, A[rt2].r, m + 1, r); push_up(rt1); } void DFS(int u, int f, int e) { int doc = 0; root[u] = build(C[u], 1, n); for(int i = Head[u]; ~i; i = E[i].nxt) { int v = E[i].v; if(v == f) continue; DFS(v, u, i); merge(root[u], root[v], 1, n); } if(u != 1) { int id = e / 2 + 1; ans[id] = A[root[u]].sum; } } int main() { // FIN; while(~scanf("%d", &n)) { edge_init(); sz = 0; memset(cnt, 0, sizeof(cnt)); for(int i = 1; i <= n; i++) { scanf("%d", &C[i]); cnt[C[i]]++; } for(int i = 1; i <= n - 1; i++) { int u, v; scanf("%d%d", &u, &v); edge_add(u, v); edge_add(v, u); } DFS(1, -1, -1); for(int i = 1; i <= n - 1; i++) { printf("%d\n", ans[i]); } } return 0; }