1. 程式人生 > >POJ 1741 Tree【Tree,點分治】

POJ 1741 Tree【Tree,點分治】

樹上的演算法真的很有意思……哈哈。

給一棵邊帶權樹,問兩點之間的距離小於等於K的點對有多少個。

將無根樹轉化成有根樹進行觀察。滿足條件的點對有兩種情況:兩個點的路徑橫跨樹根,兩個點位於同一顆子樹中。

如果我們已經知道了此時所有點到根的距離a[i],a[x] + a[y] <= k的(x, y)對數就是結果,這個可以通過排序之後O(n)的複雜度求出。然後根據分治的思想,分別對所有的兒子求一遍即可,但是這會出現重複的——當前情況下兩個點位於一顆子樹中,那麼應該將其減掉(顯然這兩個點是滿足題意的,為什麼減掉呢?因為在對子樹進行求解的時候,會重新計算)。

在進行分治時,為了避免樹退化成一條鏈而導致時間複雜度變為O(N^2),每次都找樹的重心,這樣,所有的子樹規模就會變的很小了。時間複雜度O(Nlog^2N)。

樹的重心的演算法可以線性求解。

#include <cstdio>
#include <algorithm>
#include <vector>
#include <cstring>
using namespace std;
#define N 10009
struct node {
    int v, l;
    node() {};
    node(int _v, int _l): v(_v), l(_l) {};
};
vector<node> g[N];
int n, k, size, s[N], f[N], root, d[N], K, ans;
vector<int> dep;
bool done[N];
void getroot(int now, int fa) {
    int u;
    s[now] = 1; f[now] = 0;
    for (int i=0; i<g[now].size(); i++)
        if ((u = g[now][i].v) != fa && !done[u]) {
            getroot(u, now);
            s[now] += s[u];
            f[now] = max(f[now], s[u]);
        }
    f[now] = max(f[now], size-s[now]);
    if (f[now] < f[root]) root = now;
}
void getdep(int now, int fa) {
    int u;
    dep.push_back(d[now]);
    s[now] = 1;
    for (int i=0; i<g[now].size(); i++)
        if ((u = g[now][i].v) != fa && !done[u]) {
            d[u] = d[now] + g[now][i].l;
            getdep(u, now);
            s[now] += s[u];
        }
}
int calc(int now, int init) {
    dep.clear(); d[now] = init;
    getdep(now, 0);
    sort(dep.begin(), dep.end());
    int ret = 0;
    for (int l=0, r=dep.size()-1; l<r; )
        if (dep[l] + dep[r] <= K) ret += r-l++;
        else r--;
    return ret;
}
void work(int now) {
    int u;
    ans += calc(now, 0);
    done[now] = true;
    for (int i=0; i<g[now].size(); i++)
        if (!done[u = g[now][i].v]) {
            ans -= calc(u, g[now][i].l);
            f[0] = size = s[u];
            getroot(u, root=0);
            work(root);
        }
}
int main() {

    while (scanf("%d%d", &n, &K) == 2) {
        if (n == 0 && K == 0) break;
        for (int i=0; i<=n; i++) g[i].clear();
        memset(done, false, sizeof(done));

        int u, v, l;
        for (int i=1; i<n; i++) {
            scanf("%d%d%d", &u, &v, &l);
            g[u].push_back(node(v, l));
            g[v].push_back(node(u, l));
        }
        f[0] = size = n;
        getroot(1, root=0);
        ans = 0;
        work(root);
        printf("%d\n", ans);
    }
    return 0;
}