樹上啟發式合併(dsu on tree)精巧的暴力
阿新 • • 發佈:2021-08-12
樹上啟發式合併(dsu on tree)精巧的暴力
證明
,否則這就不是輕邊了
樹上啟發式合併(dsu on tree)
雖然叫dsu但這和並查集貌似沒什麼關係
例:
給你一棵樹,每個節點有一個顏色,要求出每個子樹中數量最多的顏色並輸出
(數量相同的情況先不考慮
不重要)
當我們需要在每個子樹上統計一些資訊的時候,往往會開一個全域性的cnt陣列,試圖 dfs \(O(n)\) 掃一遍,一邊加點一邊得到答案
但對於一棵樹而言顯然有問題:當我們統計完其左子樹的資訊後,必須清空整個cnt陣列才能去掃右子樹,這樣其實就已經變成 \(O(n^2)\) 了
當然我們可以稍微偷工減料一點,因為最後一棵子樹統計完後不用清空,我們可以最後遍歷最大的那棵子樹
最大子樹可以通過一遍dfs預處理出子樹的size,記錄每個點的重兒子得到(類似樹剖)
然而就是這一點偷工減料,使得整個演算法複雜度直接降至 \(O(nlogn)\)
如果不關心證明的話,你已經學會 dsu on tree 了
證明為什麼這樣瞎搞就能獲得\(nlogn\)的複雜度:
以下通過感性理解的方式說明為什麼這東西能優化這麼多
回顧一下在每個節點處我們要做什麼:
- dfs輕兒子,並消除影響
- dfs重兒子,不消除影響
- 再統計輕子樹的影響
前兩步的操作一共是 \(O(n)\) 的,就是最樸素的從頭到尾掃一遍
現在需要考慮:在每個點處對每個輕子樹掃一遍的複雜度
如果一個點和根節點之間一共有 x 條輕邊,那麼它會被遍歷差不多 x+1 次
而輕重鏈剖分有個很好的性質:走一條輕邊時,節點數量至少被砍一半
那麼從根節點到任意節點經過的輕邊數量最多是 \(logn\) 級別的
所以其實很顯然了:複雜度就是 \(O(nlogn)\)
再看看極端情況加深理解:
樹上問題最容易被出題人的各種鏈,菊花圖,鏈加菊花圖啥的卡掉
如果這棵樹長得像鏈,它將被最後走最大子樹這一小貪心優化掉一大半;
如果這棵樹長得像菊花圖,,那麼根節點到任意節點間的輕邊數量都將是極少的;
所以你可以相信dsu on tree
程式碼(這道題的)
int n; int col[maxn]; int cnt[maxn]; ll ans[maxn]; int siz[maxn], son[maxn]; struct Edge{ int t, nt; }e[maxn*2]; int hd[maxn], ecnt = 0; inline void add(int x, int y){ e[++ecnt].t = y; e[ecnt].nt = hd[x]; hd[x] = ecnt; } void dfs1(int p, int fa){ siz[p] = 1; son[p] = 0; for(int i=hd[p];i;i=e[i].nt){ int v = e[i].t; if(v!=fa){ dfs1(v, p); siz[p] += siz[v]; if(siz[v] > siz[son[p]]) son[p] = v; } } } ll tot = 0, mxc = 0; void addcol(int c, int ad){//只計加不計減(減肯定減到0) cnt[c] += ad; if(cnt[c] > mxc){ mxc = cnt[c]; tot = c; }else if(cnt[c] == mxc){ tot += c; } } void cntall(int p, int fa, int d){ for(int i=hd[p];i;i=e[i].nt){ int v = e[i].t; if(v!=fa){ cntall(v, p, d); } } addcol(col[p], d); } void dfs(int p, int fa, int sav){ for(int i=hd[p];i;i=e[i].nt){ int v = e[i].t; if(v!=fa && v!=son[p]){ dfs(v, p, 0); } } if(son[p]) dfs(son[p], p, 1); for(int i=hd[p];i;i=e[i].nt){ int v = e[i].t; if(v!=fa && v!=son[p]){ cntall(v, p, 1); } } //此時所有子節點均已記錄 addcol(col[p], 1); ans[p] = tot; if(!sav) cntall(p, fa, -1), tot = mxc = 0; } void solve(){ cin >> n; for(int i=1;i<=n;i++) cin >> col[i]; for(int i=1;i<n;i++){ int x, y; cin >> x >> y; add(x, y); add(y, x); } dfs1(1, -1); dfs(1, -1, 1); for(int i=1;i<n;i++) cout << ans[i] << ' '; cout << ans[n] << '\n'; }