1. 程式人生 > 其它 >gym 102331 F. Fast Spanning Tree

gym 102331 F. Fast Spanning Tree

https://codeforces.com/gym/102331/problem/F

學到許多

首先有個顯然的性質,假設一條邊兩邊的權值和分別是\(x,y\),這條邊的要求是\(S\)

\(x+y \ge S\) 可以得到
\(x \ge \frac{S}{2} \ \ or \ \ y \ge \frac{S}{2}\)

我們把這條邊的限制(上界)平均分別丟到\(x, y\)上,拿一個小根堆維護
\(x+\frac{s-x-y}{2}\)

當啟發式合併兩個聯通塊\(u,v(siz[u]<siz[v])\)

\(a[v]+=a[u]\),把\(u\)的小根堆合併進去

然後檢視堆頂,如果\(a[v]\ge\)

堆頂的限制,那麼說明\(v\)已經達到堆頂那條邊的限制之一了,假設堆頂那條邊是\((x,y,s)\), 那麼只需要判斷如果\(a[x]+a[y]\ge s\)就加入答案,否則把剩下的\(s-a[x]-a[y]\)再平均分配到\(x,y\)兩個點上面的堆裡,作為新的限制。

即新的上界限制是\(a[x]+\frac{s-a[x]-a[y]}{2}\),\(y\)同理

啟發式合併的時間複雜度是\(O(nlog^2n)\)的,然後發現每條邊最多被訪問\(logC\)
容易發現這樣均攤下來時間複雜度是\(O(nlog^2n+nlogC)\)

完全能跑

可以結合程式碼理解

code:

#include<bits/stdc++.h>
#define N 400050
#define fi first
#define se second
using namespace std;
const int inf = 1e9;
struct E {
    int u, v, c;
} e[N << 1];
int fa[N], a[N], ans[N], gs;
priority_queue<pair<int, int> > q[N];
priority_queue<int> ok;
int get(int x) {
    return x == fa[x]? x : (fa[x] = get(fa[x]));
}
void add(int i) { //printf("* %d\n", i);
    int x = get(e[i].u), y = get(e[i].v);
    if(x == y) return ;
    if(a[x] + a[y] >= e[i].c) {
        ok.push(- i);
        return ;
    }
    int o = (e[i].c - a[x] - a[y] + 1) / 2;
    q[x].push(make_pair(-(a[x] + o), i));
    q[y].push(make_pair(-(a[y] + o), i)); 
}
void merge(int u, int v, int i) {
  // printf("%d %d %d\n", u, v, i);
    u = get(u), v = get(v);
    if(u == v) return ;
    ans[++ gs] = i;
    if(q[u].size() > q[v].size()) swap(u, v);

    fa[u] = v; a[v] = a[u] + a[v];
    if(a[v] > inf) a[v] = inf;
    while(q[u].size()) {
        auto x = q[u].top(); q[u].pop(); 
        if(-x.fi <= a[v]) add(x.se);
        else q[v].push(x);
    }
    while(q[v].size()) {
        auto x = q[v].top(); q[v].pop();
      //  printf("%d %d\n", x.fi, x.se);
        if(-x.fi <= a[v]) add(x.se);
        else {
            q[v].push(x);
            break;
        }
    }
}
int n, m;
int main() {
    scanf("%d%d", &n, &m);
    for(int i = 1; i <= n; i ++) scanf("%d", &a[i]);
    for(int i = 1; i <= n; i ++) fa[i] = i;
    for(int i = 1; i <= m; i ++) {
        scanf("%d%d%d", &e[i].u, &e[i].v, &e[i].c);
        add(i);
    }
    while(ok.size()) {
        int i = -ok.top(); ok.pop();
        merge(e[i].u, e[i].v, i);
    }
    //sort(ans + 1, ans + 1 + gs);
    printf("%d\n", gs);
    for(int i = 1; i <= gs; i ++) printf("%d ", ans[i]);
    return 0;
}