Codeforces 620E - New Year Tree
阿新 • • 發佈:2022-02-12
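dfs序 + 狀態壓縮 + 線段樹 (approach: DFS pre-order numbering + bitmask state compression of the colour set + segment tree with lazy assignment)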
// Codeforces 620E "New Year Tree".
//
// Queries on the subtree of vertex v:
//   "1 v c" : recolour every vertex of the subtree to colour c
//   "2 v"   : print the number of distinct colours in the subtree
//
// Approach: a pre-order (DFS) numbering makes every subtree a contiguous index
// range [id[v], id[v] + sz[v] - 1]. A segment tree over that order stores, per
// node, a 64-bit mask with bit (c - 1) set iff colour c occurs in the node's
// range. Range recolouring uses a lazy assignment tag. Combining two ranges is a
// bitwise OR, and the answer to a query is the popcount of the OR.
// Colours are expected in [1, 60] so each one maps to a single bit of the mask.
#include <bitset>
#include <cstdio>
#include <cstring>
#include <vector>

using LL = long long;

const int N = 4e5 + 10, M = N * 2;

int n, m;
int w[N];                      // w[v]: initial colour of vertex v (vertices are 1-based)
int h[N], e[M], ne[M], idx;    // adjacency lists as linked edge chains; h[v] == -1 marks the end
int id[N], nw[N], sz[N], cnt;  // id[v]: pre-order index; nw[k]: vertex with index k; sz[v]: subtree size

struct Node {
    int l, r;  // index range [l, r] covered by this node
    LL tag;    // pending assignment mask to push to both children (0 = none)
    LL state;  // OR of the colour masks of all positions in [l, r]
} tr[N * 4];

// Append the directed edge a -> b to a's adjacency chain.
void add(int a, int b) {
    e[idx] = b, ne[idx] = h[a], h[a] = idx++;
}

// Single-bit mask representing colour c.
LL colour_mask(int c) {
    return 1LL << (c - 1);
}

// Recompute node u's colour set from its two children.
void pushup(int u) {
    tr[u].state = tr[u << 1].state | tr[u << 1 | 1].state;
}

// Forward a pending whole-range assignment on u to both children.
void pushdown(int u) {
    if (!tr[u].tag) return;
    const LL t = tr[u].tag;
    tr[u << 1].tag = tr[u << 1].state = t;
    tr[u << 1 | 1].tag = tr[u << 1 | 1].state = t;
    tr[u].tag = 0;
}

// Build the tree for positions [l, r]; leaf k holds the colour of vertex nw[k].
void build(int u, int l, int r) {
    if (l == r) {
        tr[u] = {l, r, 0, colour_mask(w[nw[l]])};
        return;
    }
    tr[u] = {l, r, 0, 0};
    const int mid = (l + r) >> 1;
    build(u << 1, l, mid);
    build(u << 1 | 1, mid + 1, r);
    pushup(u);
}

// OR of the colour masks over positions [l, r] within node u's range.
LL query(int u, int l, int r) {
    if (tr[u].l >= l && tr[u].r <= r) return tr[u].state;
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    LL res = 0;
    if (l <= mid) res |= query(u << 1, l, r);
    if (r > mid) res |= query(u << 1 | 1, l, r);
    return res;
}

// Assign colour x to every position in [l, r].
void modify(int u, int l, int r, int x) {
    if (tr[u].l >= l && tr[u].r <= r) {
        tr[u].state = tr[u].tag = colour_mask(x);
        return;
    }
    pushdown(u);
    const int mid = (tr[u].l + tr[u].r) >> 1;
    if (l <= mid) modify(u << 1, l, r, x);
    if (r > mid) modify(u << 1 | 1, l, r, x);
    pushup(u);
}

// Fill id[], nw[] and sz[] with a pre-order walk from `root`.
// Uses an explicit stack instead of recursion. On a path-shaped tree with about
// 4e5 vertices, a recursive walk nests that many frames and can overflow a
// default-sized call stack. Children are visited in adjacency-chain order, so
// the numbering matches the recursive version exactly.
void dfs_order(int root) {
    struct Frame {
        int u;       // vertex being expanded
        int parent;  // vertex we came from (-1 for the root)
        int edge;    // next edge of u to examine (-1 = all done)
    };
    std::vector<Frame> stk;
    stk.reserve(n);  // depth never exceeds n

    auto enter = [&](int u, int parent) {
        id[u] = ++cnt;
        nw[cnt] = u;
        sz[u] = 1;
        stk.push_back({u, parent, h[u]});
    };

    enter(root, -1);
    while (!stk.empty()) {
        Frame &top = stk.back();
        if (top.edge == -1) {
            // u is finished: fold its subtree size into its parent's.
            const int done = top.u;
            stk.pop_back();
            if (!stk.empty()) sz[stk.back().u] += sz[done];
            continue;
        }
        const int j = e[top.edge];
        top.edge = ne[top.edge];
        // `top` may be invalidated by push_back inside enter(); its fields are read first.
        if (j != top.parent) enter(j, top.u);
    }
}

int main() {
    std::memset(h, -1, sizeof h);
    if (std::scanf("%d%d", &n, &m) != 2 || n <= 0) return 0;  // nothing to do on empty input
    for (int i = 1; i <= n; i++) std::scanf("%d", &w[i]);
    for (int i = 1; i < n; i++) {
        int a, b;
        std::scanf("%d%d", &a, &b);
        add(a, b), add(b, a);
    }

    dfs_order(1);
    build(1, 1, n);

    while (m--) {
        int type, u;
        std::scanf("%d%d", &type, &u);
        // Subtree of u occupies this contiguous pre-order range.
        const int lo = id[u], hi = id[u] + sz[u] - 1;
        if (type == 1) {
            int c;
            std::scanf("%d", &c);
            modify(1, lo, hi, c);
        } else {
            const LL mask = query(1, lo, hi);
            const int distinct = static_cast<int>(
                std::bitset<64>(static_cast<unsigned long long>(mask)).count());
            std::printf("%d\n", distinct);
        }
    }
    return 0;
}