1. 程式人生 > 其它 >CF161D Distance in Tree 題解

CF161D Distance in Tree 題解

Description

洛谷傳送門

Solution

似乎各種做法都可以過,這裡提供一篇 \(dsu\ on\ tree\) (樹上啟發式合併)的題解。

不會的同學可以看我的部落格 淺談 dsu on tree

題目要求我們求出長度為 \(k\) 的路徑有多少條。

那麼我們可以開一個桶 \(cnt_x\),表示深度為 \(x\) 的點有多少個,統計答案時 \(ans += cnt_{k - dep[x] + 2 * dep[topx]}\) (類似於樹上差分的思想)。

然後修改就比較板子了,加入一個點的話就 \(cnt_{dep[x]}++\),刪除的話就 \(cnt_{dep[x]}--\)

其他的就沒有什麼了。

具體看程式碼吧。

Code

#include <iostream>
#include <cstdio>
#include <algorithm>
#define ll long long

using namespace std;

const int N = 6e4 + 10;
int n, k;
ll ans;
struct node{
    int v, nxt;
}edge[N << 1];
int head[N], tot;
int siz[N], son[N], fa[N], dep[N];
int cnt[N];

inline void add(int x, int y){
    edge[++tot] = (node){y, head[x]};
    head[x] = tot;
}

inline void dfs(int x, int p){
    dep[x] = dep[p] + 1, siz[x] = 1, fa[x] = p;
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == p) continue;
        dfs(y, x);
        siz[x] += siz[y];
        if(!son[x] || siz[y] > siz[son[x]])
            son[x] = y;
    }
}

inline void update(int x, int topfa, int type){//type = 0: 加入   1: 刪除   2: 統計答案
    if(!type) cnt[dep[x]]++;
    else if(type == 1) cnt[dep[x]]--;
    else if(k - dep[x] + (dep[topfa] << 1) >= 0) ans += (ll)cnt[k - dep[x] + (dep[topfa] << 1)];//這裡前面要判一下,不然會 RE
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y != fa[x])
            update(y, topfa, type);
    }
}

inline void solve(int x, int type){//type = 1 表示是重兒子,type = 0 表示是輕兒子
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y != son[x] && y != fa[x]) solve(y, 0);
    }
    if(son[x]) solve(son[x], 1);//加入重兒子
    ans += (ll)cnt[dep[x] + k];//從當前點向下 k 個單位
    cnt[dep[x]]++;//加入根節點
    for(int i = head[x]; i; i = edge[i].nxt){
        int y = edge[i].v;
        if(y == fa[x] || y == son[x]) continue;
        update(y, x, 2);//統計答案
        update(y, x, 0);//加入輕兒子
    }
    if(!type) update(x, x, 1);//刪除輕兒子
}

int main(){
    scanf("%d%d", &n, &k);
    for(int i = 1; i < n; ++i){
        int u, v;
        scanf("%d%d", &u, &v);
        add(u, v), add(v, u);
    }
    dfs(1, 0);
    solve(1, 0);
    printf("%lld\n", ans);
    return 0;
}

End

本文來自部落格園,作者:xixike,轉載請註明原文連結:https://www.cnblogs.com/xixike/p/15473244.html