K - Tree 2019icpc南昌K題 (樹上啟發式合併 dsu on tree)
題目連結:https://nanti.jisuanke.com/t/42586
題意:給一棵n個節點的樹,編號1-n,每個點有點權w,問有多少組節點(x,y),滿足:
1.x != y
2.x點不是y點的祖先,y點不是x點的祖先
3.x點和y點的最短距離<=k (看了題解才知道,up to k 原來是小於等於k的意思???)
4.設x和y點的公共祖先是z,val【i】表示 i 節點的點權,要求val【z】 * 2 = val【x】 + val【y】
思路:樹上啟發式合併(dsu on tree)
為了深刻記憶,所以趁著剛學會一點,分享下自己理解的樹上啟發式合併。(正好用這道題來講講)
先貼個流程,下面會詳細說為什麼,圖搬運自大佬的部落格:https://blog.csdn.net/qq_44341728/article/details/102825145
對於這個題目,先考慮暴力的做法
假設k = 3,所有點的點權為1,看下圖
假設要算1號節點為公共節點的貢獻,我們觀察,2號節點和6號節點顯然不能滿足條件,因為他們在一條鏈上,但是2號節點和3,4,5號節點是可以湊出貢獻的,也就是2號和紅色圈圈住的子樹的每一個節點都有可能湊出貢獻。
這樣就有了一個想法,如果我們知道紅色圈住的子樹的資訊,我就可以直接算出左邊每個點貢獻(2節點和6節點)
這裡用若干個線段樹來維護資訊,每個權值都開一棵權值線段樹,維護某個權值在某個深度出現的次數。
舉個例子(上面有假設所有點權值為1),假如1號節點深度為1,那麼3號節點深度為2, 4,5號節點深度為3,那麼對於權值為1的線段樹,維護的資訊就是:深度為2的點有1個,深度為3的點有2個。
算2號節點的貢獻時,算出另一個匹配節點的權值應該是val【1】 * 2 - val【2】 = 1 * 2 - 1 = 1, 深度最大是:k + 2 * dep【1】 - dep【2】 = 3 + 2 - 2 = 3
所以就應該查紅框的子樹內,權值為1的線段樹,深度區間在【1,3】的點有多少個,這裡查到3個(即3,4,5號節點)。(關鍵)
PS:提一嘴深度最大值怎麼算,假設z是x和y的最近公共祖先,則x和y的距離 = dep【x】 + dep【y】 - 2 * dep【z】,題目要求 x和y的距離<=k, 所以換個位置就是:dep【y】<= k + dep【z】 * 2 - dep【x】
算6號節點的貢獻時,同理,應該查紅框的子樹內,權值為1的線段樹,深度區間在【1,2】的點有多少個,這裡查到1個(即3號節點)。
那麼對於3號節點為根的子樹來說,統計方法也一樣,只要我知道2號節點為根子樹資訊,可以用一樣的方法來算。
有人可能會問:那4號和5號節點怎麼統計?他們會在算3號節點為公共節點的時候算,因為算最大深度的時候需要用到最近公共祖先,所以兩個點若有貢獻,那這兩個點都應該在不同的子樹上。
重點來了,暴力的做法就是對於每個節點,我都維護這個節點為根的子樹的所有資訊,即n個權值的線段樹,顯然空間爆炸,直接MLE
那麼考慮在全域性開n個權值的線段樹,每個節點都用全域性的線段樹來維護資訊,但也是空間爆炸。所以考慮動態開點,每棵線段樹只開遍歷到的點。
空間的問題解決了,然後考慮時間,因為只有全域性的線段樹,他是所有節點共享的,做答案統計的時候要確保使用的時候資料是對的。
先來說明一下為什麼會有資料對不對的問題,遞迴地往下跑,假如先跑到2號節點,把2號節點資訊更新到線段樹裡面,然後遞歸回去跑3號節點子樹。
當統計3號節點為根的答案時,假設已經跑完了以5號節點為根的子樹,現在要計算4號節點的貢獻,按照上面的思路,就應該找5號節點為根子樹,查某個權值某個區間有多少個點。
關鍵的地方來了,因為2號節點的資料已經更新到線段樹裡了,如果不做處理直接查,那就會出問題,查的資訊都不對了。
如果暴力的解決這個問題,就是每次使用線段樹的時候,都先清空,然後跑對應的子樹每個節點,更新到線段樹上,最後再查詢。
這麼做顯然n方,時間不允許,重點又來了,這裡就正式開始介紹樹上啟發式合併了,他可以把時間優化到n*logn。(確實囉嗦了點,但為了照顧像我一樣的小白,就決定說得仔細一些......)
注意到,當統計以3號節點為根節點的答案時,除了他子樹包含的點,其他點都毫無用處,即2號節點的資訊此時不應該出現線上段樹裡。
那麼我們把2號節點的資訊刪掉不就行?當我統計完3號節點的答案後,我再把2號節點的資訊加回線段樹裡面,這不就完美了。
然後再看看時間複雜度,如果先統計2號節點為根的子樹,再統計3號節點為根的子樹,那麼操作就是:
跑2號節點為根的子樹,期間把子樹每個節點都更新到線段樹上,跑完後線上段樹上刪除子樹的每個節點(為了消除都其他子樹的影響),然後跑3號節點為根的子樹,期間把子樹每個節點都更新到線段樹上,跑完後發現1號節點的子樹都跑完了,結束遞迴。那麼發現,3號節點的子樹資訊就不需要刪除了,也就是說,如果一個節點i有n個子樹,可以選擇一個子樹只跑一次(加資訊),其他子樹都要跑三次(加資訊(統計答案)+刪資訊(消除對其他子樹的影響)+加資訊(維護i節點子樹的資訊,遞迴出去要給其他節點用))
重點又雙叒叕來了!
根據上面的分析,顯然要選一棵最大的子樹最後跑,會使得時間最優,這就是樹上啟發式合併的關鍵思想,並且這樣就可以使得時間變為nlogn。
實際上就是先跑輕兒子及其子樹,再跑重兒子及其子樹。(重兒子:父親節點的所有兒子中子樹結點數目最多(size最大)的結點,輕兒子就是除重兒子之外的其他節點,樹鏈剖分的內容)
樹上啟發式合併就學完了!除了換了一下遍歷兒子的順序,省了一次消影響(重兒子的影響不用減了),好像與暴力沒其它區別了!!
但這個優化就讓整個時間複雜度降到了嚴格 n*logn,而且可以如下證明:(證明也是搬運別人部落格的:https://blog.csdn.net/qq_44341728/article/details/102825145)
程式碼:
#include<bits/stdc++.h> #define ll long long using namespace std; const int maxn = 1e5 + 7; int sz[maxn],val[maxn],dep[maxn],son[maxn]; int T[maxn],ls[maxn*200],rs[maxn*200],tr[maxn*200],cnt,n,k; ll ans; vector<int>E[maxn]; void dfs1(int u) {//預處理每個節點的大小sz,深度dep和每個節點的重兒子son sz[u] = 1; for (auto v:E[u]) { dep[v] = dep[u] + 1; dfs1(v); sz[u] += sz[v]; if(sz[v] > sz[son[u]]) son[u] = v; } } void update(int &rt,int l,int r,int pos,int c) { if(!rt) rt = ++cnt; // 注意,動態開點 tr[rt] += c; if(l == r) return ; int mid = l + r >> 1; if(pos<=mid) update(ls[rt],l,mid,pos,c); if(mid<pos) update(rs[rt],mid+1,r,pos,c); } ll query(int rt,int l,int r,int L,int R) { if(!rt) return 0; if(L<=l && r<=R) return tr[rt]; int mid = l + r >> 1; ll ans = 0; if(L<=mid) ans += query(ls[rt],l,mid,L,R); if(mid<R) ans += query(rs[rt],mid+1,r,L,R); return ans; } void add(int u) { update(T[val[u]],1,n,dep[u],1); for (auto v:E[u]) add(v); } void del(int u) {//刪除u節點及其子樹線上段樹上的資訊 update(T[val[u]],1,n,dep[u],-1); for (auto v:E[u]) del(v); } void gao(int u,int fa) { int d = k + 2 * dep[fa] - dep[u];//最大深度 int w = 2 * val[fa] - val[u];//另一個點的點權 d = min(d,n); if(w >= 0 && w <= n) ans += query(T[w],1,n,1,d); for (auto v:E[u]) gao(v,fa);//子樹的每個點都要暴力統計 } void dfs2(int u) { // 樹上啟發式合併 for (auto v:E[u]) { // 1.先跑輕兒子及其子樹,跑完後暴力刪除 if(v == son[u]) continue; dfs2(v);//跑輕兒子v及其子樹 del(v);//刪除輕兒子v及其子樹 } if(son[u]) dfs2(son[u]);//2.跑重兒子,不刪 for (auto v:E[u]) {//3.把所有輕兒子都加回來 if(v == son[u]) continue; gao(v,u);//統計答案 (以u為根節點,其中一個點在v及其 子樹) add(v); } update(T[val[u]],1,n,dep[u],1);//把自己也加上 } int main() { int x; scanf("%d%d",&n,&k); for (int i=1; i<=n; ++i) { scanf("%d",&val[i]); //每個點的點權val } for (int i=2; i<=n; ++i) { scanf("%d",&x); E[x].push_back(i); } dep[1] = 1; dfs1(1); dfs2(1);//關鍵看這裡 printf("%lld",ans * 2); return 0; }