1. 程式人生 > 實用技巧 >codeforces 600E - Lomsat gelral (dsu on tree)

codeforces 600E - Lomsat gelral (dsu on tree)

題目連結:https://codeforces.com/problemset/problem/600/E

一直沒有點這個技能點,今天跟隊友打訓練賽,碰到一道 \(dsu\ on\ tree\) 的題寫不出來,就回來把這個題寫了

\(dsu\ on\ tree\) 運用了輕重鏈剖分的思想,先處理輕兒子的答案,然後消去輕兒子的影響,最後處理重兒子,並保留重兒子的答案,

時間複雜度 \(O(nlogn)\)

#include<iostream>
#include<algorithm>
#include<cstring>
#include<cstdio>
#include<cmath>
using namespace std;
typedef long long ll;

const int maxn = 200010;

int n;

int h[maxn], cnt = 0;
struct E{
	int to, next;
}e[maxn << 1];
void add(int u, int v){
	e[++cnt].next = h[u];
	e[cnt].to = v;
	h[u] = cnt;
}

ll sum, ans[maxn];
int mx, Son, sz[maxn], son[maxn], c[maxn], cn[maxn * 10];

void dfs1(int u, int par){
	sz[u] = 1;
	int mx = 0;
	for(int i = h[u] ; i != -1; i = e[i].next){
		int v = e[i].to;
		if(v == par) continue;
		dfs1(v, u);
		if(sz[v] > mx){
			mx = sz[v];
			son[u] = v;
		} 
		sz[u] += sz[v]; 
	}
}

void update(int u, int par, int val){
	cn[c[u]] += val;
	if(cn[c[u]] > mx) mx = cn[c[u]], sum = c[u];
	else if(cn[c[u]] == mx) sum += c[u];
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to; 
		if(v == par || v == Son) continue;
		update(v, u, val);
	} 
}

void dfs2(int u, int par, int is){ // is: 是不是重兒子 
	//處理兒子的答案
	// 先處理輕兒子 處理完要消除影響 
	for(int i = h[u] ; i != -1 ; i = e[i].next){
		int v = e[i].to;
		if(v == par || v == son[u]) continue;
		dfs2(v, u, 0);
	}

	// 再處理重兒子 重兒子的答案保留 
 	sum = 0, mx = 0;
	if(son[u]) dfs2(son[u], u, 1); 
	Son = son[u];

//	printf("%d\n", u);
//	for(int i = 1 ; i <= n ; ++i){
//		printf("%d ", cn[i]);
//	} printf("\n");
	update(u, par, 1); Son = 0;
//	for(int i = 1 ; i <= n ; ++i){
//		printf("%d ", cn[i]);
//	} printf("\n");
	
	ans[u] = sum;
	
	if(!is) update(u, par, -1);
}

ll read(){ ll s = 0, f = 1; char ch = getchar(); while(ch < '0' || ch > '9'){ if(ch == '-') f = -1; ch = getchar(); } while(ch >= '0' && ch <= '9'){ s = s * 10 + ch - '0'; ch = getchar(); } return s * f; }

int main(){
	memset(h, -1, sizeof(h));
	n = read(); 
	for(int i = 1 ; i <= n ; ++i){
		c[i] = read(); 
	}
	
	int u, v;
	for(int i = 1 ; i < n ; ++i){
		u = read(), v = read();
		add(u, v), add(v, u);
	}
	
	dfs1(1, 0);
	
//	for(int i = 1 ; i <= n ; ++i) printf("%d ", son[i]); printf("\n");
	dfs2(1, 0, 1);
	
	for(int i = 1 ; i <= n ; ++i){
		printf("%lld ", ans[i]);
	} printf("\n");
	
	return 0;
}