1. 程式人生 > 實用技巧 >異象石(引理證明)

異象石(引理證明)

題面

Adera是Microsoft應用商店中的一款解謎遊戲。

異象石是進入Adera中異時空的引導物,在Adera的異時空中有一張地圖。

這張地圖上有N個點,有N-1條雙向邊把它們連通起來。

起初地圖上沒有任何異象石,在接下來的M個時刻中,每個時刻會發生以下三種類型的事件之一:

地圖的某個點上出現了異象石(已經出現的不會再次出現;
地圖某個點上的異象石被摧毀(不會摧毀沒有異象石的點;
向玩家詢問使所有異象石所在的點連通的邊集的總長度最小是多少。
請你作為玩家回答這些問題。

輸入格式

第一行有一個整數N,表示點的個數。

接下來N-1行每行三個整數x,y,z,表示點x和y之間有一條長度為z的雙向邊。

第N+1行有一個正整數M。

接下來M行每行是一個事件,事件是以下三種格式之一:

”+ x” 表示點x上出現了異象石

”- x” 表示點x上的異象石被摧毀

”?” 表示詢問使當前所有異象石所在的點連通所需的邊集的總長度最小是多少。

輸出格式

對於每個 ? 事件,輸出一個整數表示答案。

資料範圍

1≤N,M≤10^5,
1≤x,y≤N,
x≠y,
1≤z≤109

輸入樣例

6
1 2 1
1 3 5
4 1 7
4 5 3
6 4 2
10
+ 3
+ 1
?
+ 6
?
+ 5
?
- 6
- 3
?

輸出樣例

5
14
17
10

題解

思路和藍書上一樣, 按dfs序對答案進行修改, 思路就不再說了,

主要說下為什麼 按dfs這樣算 是答案的兩倍

證明(數學歸納法):
\(f(x,y)=d[x]+d[y]-2*d[lca(x,y)]\)

當 k = 1, ans=0
當 k = 2, \(ans=f(a_1,a_2)+f(a_2,a_1)\), 答案顯然是兩倍
當 k = 3, \(ans=f(a_1,a_2)+f(a_2,a_3)+f(a_3,a_1)\)
     對於n = 2少了個 \(f(a_2,a_1)\), 多了\(f(a_2,a_3)+f(a_3,a_1)\)
     即多了 \(2*d[3]+2*d[lca(a_1,a_2)]-2*d[lca(a_2,a_3)]-2*d[lca(a_1,a_3)]\)


     我們發現都乘了個 2, 這便是加入a_3之後對答案 兩倍 的影響
     由於我們是嚴格按照 dfs序 算的, 所以 a_3 對於 a_2有
     1.a_3是a_2的子節點, 那麼剛才一長串就是 a_2到a_1 和 a_2到a_3的距離的兩倍
     2.a_3和a_2在不同的兩顆子樹上\([lca(a_2,a_3) \neq a_2]\), 剛才的一長串就是 a_2到a_1 和 a_3到\(lca(a_1,a_2,a_3)\)距離的兩倍
當 k = n, \(ans=\sum_{i=1}^{n-1}f(a_1,a_i)+f(a_n,a_1)\), 比 k = n - 1 多出的部分也分為
     1.\(a_{n-1}是a_n\)的子節點, 那麼多出部分為....
     2.\(a_{n-1}和a_n\)在不同的兩顆子樹上\([lca(a_{n-1},a_n) \neq a_2]\), 多出的部分為....

就簡單證明一下, 主要不懂引理, 就不會寫(暴力超時), 具體程式碼如下

#include <bits/stdc++.h>
#define all(n) (n).begin(), (n).end()
#define se second
#define fi first
#define pb push_back
#define mp make_pair
#define sqr(n) (n)*(n)
#define rep(i,a,b) for(int i=(a);i<=(b);++i)
#define per(i,a,b) for(int i=(a);i>=(b);--i)
#define IO ios::sync_with_stdio(false);cin.tie(nullptr);cout.tie(nullptr)
using namespace std;
typedef long long ll;
typedef pair<int, int> PII;
typedef pair<ll, ll> PLL;
typedef vector<int> VI;
typedef double db;

const int N = 1e5 + 5;

int n, m, _, k;
int h[N], ne[N << 1], to[N << 1], co[N << 1], tot;
int dfn[N], cnt, d[N], f[N][30], t;
ll dist[N], ans;
char s[3];
set<PII> st;

void add(int u, int v, int c) {
    ne[++tot] = h[u]; h[u] = tot; to[tot] = v; co[tot] = c;
}

void bfs(int s) {
    queue<int> q;
    q.push(s); d[s] = 1; dist[s] = 0;
    rep (i, 0, t) f[s][i] = 0;

    while (!q.empty()) {
        int x = q.front(); q.pop();
        for (int i = h[x]; i; i = ne[i]) {
            int y = to[i];
            if (d[y]) continue;
            d[y] = d[x] + 1;
            dist[y] = dist[x] + co[i];
            f[y][0] = x;

            for (int j = 1; j <= t; ++j) 
                f[y][j] = f[f[y][j - 1]][j - 1];

            q.push(y);
        }
    }
}

int lca(int x, int y) {
    if (d[x] > d[y]) swap(x, y);
    per (i, t, 0) 
        if (d[f[y][i]] >= d[x]) y = f[y][i];

    if (x == y) return x;

    per (i, t, 0)
        if (f[x][i] != f[y][i]) x = f[x][i], y = f[y][i];

    return f[x][0]; 
}

void dfs(int u) {
    dfn[u] = ++cnt;
    for (int i = h[u]; i; i = ne[i]) {
        int y = to[i];
        if (dfn[y]) continue;
        dfs(y);
    }
}

void worka() {
    auto it = st.upper_bound({ dfn[k], k });
    int x, y;
    if (it == st.begin() || it == st.end()) x = st.begin()->se, y = st.rbegin()->se;
    else  x = it->se, y = (--it)->se;

    ans += (dist[k] << 1) - ((dist[lca(k, x)] + dist[lca(k, y)] - dist[lca(x, y)]) << 1);
    st.insert({ dfn[k], k });
}

void workb() {
    if (st.size() == 1) { st.clear(); return; }

    auto it = st.lower_bound({ dfn[k], k }), ita = it; ++ita;
    int x, y;
    if (ita == st.end()) y = st.begin()->se;
    else y = ita->se;
    if (it == st.begin()) x = st.rbegin()->se;
    else x = (--it)->se;

    ans -= (dist[k] << 1) - ((dist[lca(k, x)] + dist[lca(k, y)] - dist[lca(x, y)]) << 1);
    st.erase(--ita);
}

int main() {
    IO; cin >> n;
    rep(i, 2, n) {
        int u, v, c; cin >> u >> v >> c;
        add(u, v, c); add(v, u, c);
    }

    t = log2(n - 1) + 1;
    dfs(1); bfs(1);

    cin >> m;
    rep(i, 1, m) {
        cin >> s;
        if (s[0] == '?') cout << (ans >> 1) << '\n';
        else {
            cin >> k;
            if (st.empty()) st.insert({ dfn[k], k });
            else if (s[0] == '+') worka();
            else workb();
        }
    }
    return 0;
}