dsu on tree 入門
Dus on tree
樹上並查集?。 啊這,並不是的啦,他的全稱是樹上啟發式合併。
他主要解決不帶修改且主要詢問子樹資訊的樹上問題。
先來看到例題,CF600E 。
這不就是樹上莫隊的經典題嗎?。 會莫隊的大佬一眼就秒了。
不會的蒟蒻我只能打打暴力,騙騙分。
首先,我們暴力其實很好打,就是對每個點都統計一下他子樹的答案,時間複雜度為 O(\(n^2\)).
這顯然,我們是不能接受的,我們需要優化。
Dus on tree 利用輕重鏈剖分的思想,把他的複雜度優化為 O(\(n log n\)) 的。
我們遞迴處理子樹的時候,重複計算了很多種狀態,我們就要考慮剪枝,減去重複的狀態。
實際上,我們最後處理的子樹肯定是不需要刪除他的貢獻的,在計算他的父親 \(x\)
就可以直接推出他父親的答案。
我們就要確定一種順序,使我們遍歷子樹節點的數目儘可能少。
我們選的最後遍歷的要為他的重兒子(重兒子子樹中的節點是最多的)
就這樣 Dus on tree 的演算法流程就是:
-
遞迴處理每個輕兒子,同時消除輕兒子的影響。
-
遞迴重兒子,不消除重兒子的影響。
-
統計輕兒子對答案的貢獻。
-
將輕兒子和重兒子的資訊合併,得出這個節點的答案。
-
消除輕兒子對答案的影響
大致的程式碼張這樣:
void dfs(int x,int fa,int opt) { for(int i = head[x]; i; i = e[i].net) { int to = e[i].to; if(to == fa || to == son[x]) continue; dfs(to,1);//遞迴輕兒子 } if(son[x]) dfs(son[x],0);//帝國重兒子 add(x);//統計輕兒子對答案的貢獻 ans[x] = now_ans;//合併輕兒子和重兒子得出這個點的答案 if(opt == 1) delet(x);//如果他是輕兒子,消除他的影響 }
圖例長這樣
紫色的是他的輕兒子,紅色的是他的重兒子,序號是他的遍歷順序。
這不是和普通的爆搜沒什麼區別嗎?為什麼複雜度不是O(\(n^2\))
下面,我們簡單證明一下他的複雜度,不願意看的可以直接跳過。
性質:一個節點到根的路徑上輕邊個數不會超過 logn 條
證明:設根到該節點有 \(x\) 的輕邊,該節點的大小為 \(y\),根據輕重邊的定義,輕邊所連向的點的大小不會成為該節點
總大小的一般。這樣每經過一條輕邊,\(y\) 的上限就會 /2,所以 \(x\) < \(log n\)
然而這條性質並不能解決問題,我們考慮一個點會被訪問多少次
一個點被訪問到,只有兩種情況
1、在暴力統計輕邊的時候訪問到。根據前面的性質,該次數 < \(log n\)
2、通過重邊 在遍歷的時候被訪問到,顯然只有一次
如果統計一個點的貢獻的複雜度為O(1)的話,該演算法的複雜度為O(\(nlog n\))
我們上面的例題就可以直接套板子了。
Code
#include<iostream>
#include<cstdio>
#include<algorithm>
using namespace std;
#define int long long
const int N = 1e5+10;
int n,m,tot,u,v,max_c,now_ans,heavy_son;
int siz[N],fa[N],son[N],head[N],cnt[N],c[N],ans[N];
struct node
{
int to,net;
}e[N<<1];
void add_(int x,int y)
{
e[++tot].to = y;
e[tot].net = head[x];
head[x] = tot;
}
inline int read()
{
int s = 0,w = 1; char ch = getchar();
while(ch < '0' || ch > '9'){if(ch == '-') w = -1; ch = getchar();}
while(ch >= '0' && ch <= '9'){s =s * 10+ch - '0'; ch = getchar();}
return s * w;
}
void get_tree(int x)
{
siz[x] = 1;
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x]) continue;
fa[to] = x;
get_tree(to);
siz[x] += siz[to];
if(siz[son[x]] < siz[to]) son[x] = to;
}
}
void add(int x,int val)
{
cnt[c[x]] += val;
// printf("----------->\n");
// cout<<max_c<<" "<<now_ans<<endl;
if(cnt[c[x]] > max_c)
{
max_c = cnt[c[x]];
now_ans = c[x];
}
else if(cnt[c[x]] == max_c)
{
now_ans += c[x];
}
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x] || to == heavy_son) continue;
add(to,val);
}
}
void dfs(int x,int type)
{
for(int i = head[x]; i; i = e[i].net)
{
int to = e[i].to;
if(to == fa[x] || to == son[x]) continue;
dfs(to,1);
}
if(son[x])
{
dfs(son[x],0);
heavy_son = son[x];
}
add(x,1);
ans[x] = now_ans;
heavy_son = 0;
if(type == 1)
{
add(x,-1);
max_c = now_ans = 0;
}
}
signed main()
{
n = read();
for(int i = 1; i <= n; i++) c[i] = read();
for(int i = 1; i <= n-1; i++)
{
u = read(); v = read();
add_(u,v); add_(v,u);
}
get_tree(1); dfs(1,0);
for(int i = 1; i <= n; i++) printf("%lld ",ans[i]);
return 0;
}