hiho1715 樹的連通問題 動態開點線段樹 + 線段樹合併
阿新 • • 發佈:2020-10-04
hiho1715 樹的連通問題
題目連結
線段樹 + 動態開店 + 線段樹合併。
暴力\(O(n^2)\)。不可做。
我們考慮問題轉化,求每一條邊的貢獻,也就是有多少區間跨過這一條邊。
這麼寫不好寫,正難則反,我們求這個問題的對偶問題,總區間數\(-\)有多少區間沒有跨過這一條邊。
我們設當前點\(x\)的父親為\(fa\),出現在\(x\)這顆子樹裡的點標為1,沒有出現的標為0。\(fa\)也這麼標記。
總區間數很好算,\(n * (n + 1) / 2\),有多少區間沒跨過這一條邊就是這顆子樹同為1的對數和同為0的對數。
我們對每個節點搞一顆線段樹,得用動態開點,\(lmax1,rmax1,lmax0, rmax0\)
#include <bits/stdc++.h> #define mid ((l + r) >> 1) using namespace std; inline long long read() { long long s = 0, f = 1; char ch; while(!isdigit(ch = getchar())) (ch == '-') && (f = -f); for(s = ch ^ 48;isdigit(ch = getchar()); s = (s << 1) + (s << 3) + (ch ^ 48)); return s * f; } const int N = 1e5 + 5; int n, cnt, tot; long long ans; int rt[N], head[N]; struct edge { int to, nxt; } e[N << 1]; struct tree { int lc, rc; int lmax1, rmax1, lmax0, rmax0; long long sum0, sum1; } t[N << 4]; void add(int x, int y) { e[++cnt].nxt = head[x]; head[x] = cnt; e[cnt].to = y; } long long calc(int x) { return 1ll * (x + 1) * x / 2; } void modify(int o, int l, int r) { t[o].sum0 = calc(r - l + 1); t[o].lmax0 = t[o].rmax0 = r - l + 1; } void up(int o, int l, int r) { if(!t[o].lc) modify(t[o].lc, l, mid); if(!t[o].rc) modify(t[o].rc, mid + 1, r); t[o].sum0 = t[t[o].lc].sum0 + t[t[o].rc].sum0 + 1ll * t[t[o].lc].rmax0 * t[t[o].rc].lmax0; t[o].sum1 = t[t[o].lc].sum1 + t[t[o].rc].sum1 + 1ll * t[t[o].lc].rmax1 * t[t[o].rc].lmax1; t[o].lmax0 = t[t[o].lc].lmax0 == mid - l + 1 ? t[t[o].lc].lmax0 + t[t[o].rc].lmax0 : t[t[o].lc].lmax0; t[o].lmax1 = t[t[o].lc].lmax1 == mid - l + 1 ? t[t[o].lc].lmax1 + t[t[o].rc].lmax1 : t[t[o].lc].lmax1; t[o].rmax0 = t[t[o].rc].rmax0 == r - mid ? t[t[o].rc].rmax0 + t[t[o].lc].rmax0 : t[t[o].rc].rmax0; t[o].rmax1 = t[t[o].rc].rmax1 == r - mid ? t[t[o].rc].rmax1 + t[t[o].lc].rmax1 : t[t[o].rc].rmax1; } void insert(int &o, int l, int r, int x) { if(!o) o = ++tot; if(l == r) { t[o].lmax1 = t[o].rmax1 = t[o].sum1 = 1; return ; } if(x <= mid) insert(t[o].lc, l, mid, x); if(x > mid) insert(t[o].rc, mid + 1, r, x); up(o, l, r); } int merge(int x, int y, int l, int r) { if(!x || !y) return x + y; if(l == r) return t[x].sum1 ? x : y; t[x].lc = merge(t[x].lc, t[y].lc, l, mid); t[x].rc = merge(t[x].rc, t[y].rc, mid + 1, r); up(x, l, r); return x; } void dfs(int x, int fa) { for(int i = head[x]; i ; i = e[i].nxt) { int y = e[i].to; if(y == fa) continue; dfs(y, x); rt[x] = merge(rt[x], rt[y], 1, n); } insert(rt[x], 1, n, x); if(x != 1) ans += calc(n) - t[rt[x]].sum0 - t[rt[x]].sum1; } int main() { n = read(); for(int i = 1, x, y;i <= n - 1; i++) x = read(), y = read(), add(x, y), add(y, x); dfs(1, 0); printf("%lld", ans); return 0; }