1. 程式人生 > 其它 >樹上啟發式合併(dsu on tree)精巧的暴力

樹上啟發式合併(dsu on tree)精巧的暴力

樹上啟發式合併(dsu on tree)精巧的暴力

樹上啟發式合併(dsu on tree)

雖然叫dsu但這和並查集貌似沒什麼關係

例:

給你一棵樹,每個節點有一個顏色,要求出每個子樹中數量最多的顏色並輸出

(數量相同的情況先不考慮 不重要

當我們需要在每個子樹上統計一些資訊的時候,往往會開一個全域性的cnt陣列,試圖 dfs \(O(n)\) 掃一遍,一邊加點一邊得到答案

但對於一棵樹而言顯然有問題:當我們統計完其左子樹的資訊後,必須清空整個cnt陣列才能去掃右子樹,這樣其實就已經變成 \(O(n^2)\)

當然我們可以稍微偷工減料一點,因為最後一棵子樹統計完後不用清空,我們可以最後遍歷最大的那棵子樹

最大子樹可以通過一遍dfs預處理出子樹的size,記錄每個點的重兒子得到(類似樹剖)

然而就是這一點偷工減料,使得整個演算法複雜度直接降至 \(O(nlogn)\)

如果不關心證明的話,你已經學會 dsu on tree 了

證明為什麼這樣瞎搞就能獲得\(nlogn\)​的複雜度

以下通過感性理解的方式說明為什麼這東西能優化這麼多

回顧一下在每個節點處我們要做什麼:

  • dfs輕兒子,並消除影響
  • dfs重兒子,不消除影響
  • 再統計輕子樹的影響

前兩步的操作一共是 \(O(n)\) 的,就是最樸素的從頭到尾掃一遍

現在需要考慮:在每個點處對每個輕子樹掃一遍的複雜度

如果一個點和根節點之間一共有 x 條輕邊,那麼它會被遍歷差不多 x+1 次

而輕重鏈剖分有個很好的性質:走一條輕邊時,節點數量至少被砍一半

,否則這就不是輕邊了

那麼從根節點到任意節點經過的輕邊數量最多是 \(logn\)​ 級別的

所以其實很顯然了:複雜度就是 \(O(nlogn)\)

再看看極端情況加深理解:

樹上問題最容易被出題人的各種鏈,菊花圖,鏈加菊花圖啥的卡掉

如果這棵樹長得像鏈,它將被最後走最大子樹這一小貪心優化掉一大半;

如果這棵樹長得像菊花圖,,那麼根節點到任意節點間的輕邊數量都將是極少的;

所以你可以相信dsu on tree

程式碼(這道題的)

int n;
int col[maxn];
int cnt[maxn];
ll ans[maxn];
int siz[maxn], son[maxn];

struct Edge{
    int t, nt;
}e[maxn*2];

int hd[maxn], ecnt = 0;

inline void add(int x, int y){
    e[++ecnt].t = y;
    e[ecnt].nt = hd[x];
    hd[x] = ecnt;
}

void dfs1(int p, int fa){
    siz[p] = 1;
    son[p] = 0;
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa){
            dfs1(v, p);
            siz[p] += siz[v];
            if(siz[v] > siz[son[p]]) son[p] = v;
        }
    }
}

ll tot = 0, mxc = 0;
void addcol(int c, int ad){//只計加不計減(減肯定減到0)
    cnt[c] += ad;
    if(cnt[c] > mxc){
        mxc = cnt[c];
        tot = c;
    }else if(cnt[c] == mxc){
        tot += c;
    }
}

void cntall(int p, int fa, int d){
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa){
            cntall(v, p, d);
        }
    }
    addcol(col[p], d);
}

void dfs(int p, int fa, int sav){
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa && v!=son[p]){
            dfs(v, p, 0);
        }
    }
    if(son[p]) dfs(son[p], p, 1);
    for(int i=hd[p];i;i=e[i].nt){
        int v = e[i].t;
        if(v!=fa && v!=son[p]){
            cntall(v, p, 1);
        }
    }
    //此時所有子節點均已記錄
    addcol(col[p], 1);
    ans[p] = tot;
    if(!sav) cntall(p, fa, -1), tot = mxc = 0;
}

void solve(){
    cin >> n;
    for(int i=1;i<=n;i++) cin >> col[i];
    for(int i=1;i<n;i++){
        int x, y;
        cin >> x >> y;
        add(x, y); add(y, x);
    }
    dfs1(1, -1);
    dfs(1, -1, 1);
    for(int i=1;i<n;i++) cout << ans[i] << ' ';
    cout << ans[n] << '\n';
}