DSU on tree 樹上啟發式合併
阿新 • • 發佈:2020-09-14
DSU on tree 樹上啟發式合併
首先介紹一下大概流程:
- 首先處理所有輕鏈。
- 如果有重鏈,再處理重鏈。注意重鏈的值不刪除。
- 這樣只需要把輕鏈的貢獻算一下加上就好了,不需要處理重鏈。
- 最後,如果是輕鏈,就要刪除其對貢獻的影響。
void dfs(int u, int pre, int opt){ for(int i = head[u]; i != -1; i = nxt[i]){ int v = to[i]; if(v == pre) continue; if(v != son[u]) dfs(v, u, 0); } if(son[u]) dfs(son[u], u, 1), tson = son[u]; Addval(), tson = 0; if(!opt) Delval(); }
再換一個角度,對於每個點,如果它是父節點的重兒子,返回前就不需要刪除計算的貢獻,這樣相當於重兒子對父節點的貢獻已經算了。如果它是父節點的輕兒子,那麼返回前就需要刪除輕兒子對父節點的影響。也就是從低向上保留重兒子對當前點u的貢獻,再自上向下把輕兒子的貢獻都加上。
複雜度: \(O(nlog^{n})\)
例題: Cf600E
#include<iostream> #include<cstdio> #include<cstring> #include<algorithm> #include<cmath> #include<map> #include<queue> #include<vector> #include<string> #include<fstream> using namespace std; #define rep(i, a, n) for(int i = a; i <= n; ++ i); #define per(i, a, n) for(int i = n; i >= a; -- i); typedef long long ll; const int N = 2e5 + 105; const int mod = 1e9 + 7; const double Pi = acos(- 1.0); const ll INF = 1e18; const int G = 3, Gi = 332748118; ll qpow(ll a, ll b) { ll res = 1; while(b){ if(b & 1) res = (res * a) % mod; a = (a * a) % mod; b >>= 1;} return res; } ll gcd(ll a, ll b) { return b ? gcd(b, a % b) : a; } // bool cmp(int a, int b){return a > b;} // int n, m, A, B; int head[N], cnt = 0; int to[N << 1], nxt[N << 1]; int son[N], siz[N],dfn[N]; int col[N]; ll ans[N], sum; int num[N], Max = 0, tson; void add(int u, int v){ to[cnt] = v, nxt[cnt] = head[u], head[u] = cnt ++; to[cnt] = u, nxt[cnt] = head[v], head[v] = cnt ++; } void dfs1(int u, int pre){ siz[u] = 1; int maxx = -1; for(int i = head[u]; i != -1; i = nxt[i]){ int v = to[i]; if(v == pre) continue; dfs1(v, u); siz[u] += siz[v]; if(siz[v] > maxx){ maxx = siz[v]; son[u] = v; } } } void Addval(int u, int pre, int val){ num[col[u]] += val; if(num[col[u]] > Max) Max = num[col[u]], sum = col[u]; else if(num[col[u]] == Max) sum += col[u]; for(int i = head[u]; i != -1; i = nxt[i]){ int v = to[i]; if(v == pre || v == tson) continue; Addval(v, u, val); } } void dfs(int u, int pre, int opt){ for(int i = head[u]; i != -1; i = nxt[i]){ int v = to[i]; if(v == pre) continue; if(v != son[u]) dfs(v, u, 0); } if(son[u]) dfs(son[u], u, 1), tson = son[u]; Addval(u, pre, 1), tson = 0; ans[u] = sum; if(!opt){ Addval(u, pre, -1); sum = 0, Max = 0; } } int main() { scanf("%d",&n); cnt = 0; for(int i = 0; i <= n; ++ i) head[i] = -1; for(int i = 1; i <= n; ++ i) scanf("%d",&col[i]); for(int i = 1; i < n; ++ i){ int x, y; scanf("%d%d",&x,&y); add(x, y); } dfs1(1, 0); dfs(1, 0, 0); for(int i = 1; i <= n; ++ i){ printf("%lld",ans[i]); if(i == n) printf("\n"); else printf(" "); } return 0; }