1. 程式人生 > >線段樹合併 csu1811 Tree Intersection

線段樹合併 csu1811 Tree Intersection

題意:對於每條邊,把這條邊刪了,樹分成了兩個集合,求這兩個集合中共同的顏色數量。

對於節點u,就看u節點的子樹中,有多少種顏色沒有到達這種顏色的上限,就是對u所對應的邊的答案

方法一:我們用線段樹合併來維護,程式碼寫起來比較麻煩。

之所以可以用線段樹合併,是因為有一個結論:如果初始的時候只有一個葉子的線段樹有n個,那麼最後合併成一個線段樹,總複雜度只有O(nlogn)

#include <map>
#include <set>
#include <cmath>
#include <ctime>
#include <stack>
#include <queue>
#include <cstdio>
#include <cctype>
#include <bitset>
#include <string>
#include <vector>
#include <cstring>
#include <iostream>
#include <algorithm>
#include <functional>
#define fuck(x) cout<<"["<<x<<"]";
#define FIN freopen("input.txt","r",stdin);
#define FOUT freopen("output.txt","w+",stdout);
//#pragma comment(linker, "/STACK:102400000,102400000")
using namespace std;
typedef long long LL;
typedef pair<int, int>PII;
 
const int MX = 1e5 + 5;
const int MD = 2e6 + 5;
 
struct Edge {
    int v, nxt;
} E[MX << 1];
int Head[MX], erear;
void edge_init() {
    erear = 0;
    memset(Head, -1, sizeof(Head));
}
void edge_add(int u, int v) {
    E[erear].v = v;
    E[erear].nxt = Head[u];
    Head[u] = erear++;
}
 
struct Node {
    int l, r;
    int val, sum;
} A[MD];
int n, sz, id[MX], C[MX], root[MX];
int ans[MX], cnt[MX];
 
void push_up(int rt) {
    A[rt].sum = A[A[rt].l].sum + A[A[rt].r].sum;
}
int build(int c, int l, int r) {
    int rt = ++sz;
    A[rt].l = A[rt].r = 0;
    A[rt].sum = 0;
    if(l == r) {
        A[rt].val = 1;
        A[rt].sum = (A[rt].val != cnt[l]);
        return rt;
    }
    int m = (l + r) >> 1;
    if(c <= m) A[rt].l = build(c, l, m);
    else A[rt].r = build(c, m + 1, r);
    push_up(rt);
    return rt;
}
void merge(int &rt1, int rt2, int l, int r) {
    if(!rt1 || !rt2) {
        if(!rt1) rt1 = rt2;
        return;
    }
    if(l == r) {
        A[rt1].val += A[rt2].val;
        A[rt1].sum = (A[rt1].val != cnt[l]);
        return;
    }
    int m = (l + r) >> 1;
    merge(A[rt1].l, A[rt2].l, l, m);
    merge(A[rt1].r, A[rt2].r, m + 1, r);
    push_up(rt1);
}
void DFS(int u, int f, int e) {
    int doc = 0;
    root[u] = build(C[u], 1, n);
    for(int i = Head[u]; ~i; i = E[i].nxt) {
        int v = E[i].v;
        if(v == f) continue;
        DFS(v, u, i);
        merge(root[u], root[v], 1, n);
    }
    if(u != 1) {
        int id = e / 2 + 1;
        ans[id] = A[root[u]].sum;
    }
}
 
int main() {
    // FIN;
    while(~scanf("%d", &n)) {
        edge_init(); sz = 0;
        memset(cnt, 0, sizeof(cnt));
 
        for(int i = 1; i <= n; i++) {
            scanf("%d", &C[i]);
            cnt[C[i]]++;
        }
        for(int i = 1; i <= n - 1; i++) {
            int u, v;
            scanf("%d%d", &u, &v);
            edge_add(u, v);
            edge_add(v, u);
        }
        DFS(1, -1, -1);
        for(int i = 1; i <= n - 1; i++) {
            printf("%d\n", ans[i]);
        }
    }
    return 0;
}