【BZOJ3281】小P的煩惱
阿新 • • 發佈:2018-11-08
【題目連結】
【思路要點】
- 為原圖中的每一條邊新建一個點,建出新圖以 $s$ 為根的支配樹(程式按拓撲序逐點求支配樹父親,因此假設從 $s$ 可達的部分是 DAG);支配樹上 $t$ 到 $s$ 的路徑上每一個代表邊的點,就是一條每條 $s \to t$ 路徑都必須經過的邊。
- 求出相鄰的兩條必經邊之間的最短路,再用雙指針(two pointers)解決剩餘問題即可。
- 時間複雜度 $O((N+M)\log(N+M))$,瓶頸在於用倍增求 LCA 來構造支配樹。
【程式碼】
// BZOJ3281 solution.
//
// Approach (as implemented below):
//  * Every input edge i becomes its own node n + i that carries the edge
//    length in val[]; vertex nodes 1..n carry length 0.
//  * Nodes reachable from s are put in topological order (Kahn's algorithm)
//    and a dominator tree rooted at s is built by taking, for each node, the
//    LCA of its reachable predecessors in the tree built so far. This only
//    works when the part of the graph reachable from s is acyclic.
//  * The edge nodes on the dominator-tree path from t up to s are the edges
//    every s->t path must use. len[] stores them, interleaved with the
//    shortest distance between consecutive ones, walking backwards from t.
//  * solve() then runs two-pointer sweeps over len[].
#include <algorithm>
#include <cctype>
#include <climits>
#include <cstdio>
#include <cstring>
#include <vector>
using namespace std;

const int MAXN = 3e5 + 5;  // capacity of every per-node array; n + m must stay below it
const int MAXLOG = 20;     // binary-lifting levels for father[][]
const int INF = INT_MAX;   // dp[] marker for "cannot reach t"

template <typename T> void chkmax(T &x, T y) { x = max(x, y); }
template <typename T> void chkmin(T &x, T y) { x = min(x, y); }

// Reads one integer from stdin, skipping non-digit characters; every '-'
// seen while skipping flips the sign. Returns (leaving x == 0) at EOF
// instead of looping forever.
template <typename T> void read(T &x) {
    x = 0;
    int f = 1;
    int c = getchar();  // int, not char: isdigit() needs an unsigned-char value or EOF
    for (; !isdigit(c); c = getchar()) {
        if (c == EOF) return;
        if (c == '-') f = -f;
    }
    for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
    x *= f;
}

int n, m, s, t, rope, tot, cnt;
int val[MAXN];             // val[v]: length of edge node v; 0 for vertex nodes
int d[MAXN];               // in-degree counted over reachable predecessors only
int q[MAXN];               // q[1..tot]: reachable nodes in topological order
int len[MAXN];             // len[1..cnt]: must-pass edge, gap, must-pass edge, ... (from t backwards)
int father[MAXN][MAXLOG];  // binary-lifting table of the dominator tree (0 = above the root)
int dp[MAXN];              // shortest distance to the next recorded must-pass edge (or to t)
int depth[MAXN];           // depth in the dominator tree; s has depth 1, node 0 has depth 0
bool vis[MAXN];            // reachable from s
vector<int> a[MAXN], b[MAXN];  // forward / reverse adjacency of the split graph

// Two-pointer sweeps over a[1..cnt], where odd positions are must-pass edge
// lengths and even positions are gap lengths. Only odd positions add to the
// score (b[] below). Pass 1 tries a single window of total length
// 2 * rope; passes 2-3 combine a prefix window and a suffix window of
// length rope each. Leftover window length may spill into the neighbouring
// element(s). Returns the best score found.
// NOTE(review): "score" is presumably the must-pass length the rope lets the
// traveller skip - confirm against the problem statement.
// Requires a[0] == 0 and a[cnt + 1] == 0 (sentinels).
int solve(int *a) {
    // Storage is shifted by one so that index -1, reached when the left
    // pointer runs off the front, is a valid zero slot rather than an
    // out-of-bounds access.
    static int b_store[MAXN + 4], suf_store[MAXN + 4];
    int *b = b_store + 1;
    int *suf = suf_store + 1;
    b[-1] = b[0] = 0;
    fill(suf_store, suf_store + cnt + 4, 0);  // clears suf[-1 .. cnt + 2]
    for (int i = 1; i <= cnt + 1; i++) b[i] = (i & 1) ? a[i] : 0;

    int ans = 0;

    // Pass 1: window (l, r] of length <= 2 * rope, swept right to left.
    int l = cnt, r = cnt, sum = 0, now = 0;
    while (l >= 0) {
        if (sum + a[l] <= 2 * rope) {
            sum += a[l];
            now += b[l];
            l--;
        } else {
            sum -= a[r];
            now -= b[r];
            r--;
        }
        chkmax(ans, now + min(2 * rope - sum, b[l] + b[r + 1]));
    }

    // Pass 2: suf[i] = best single window of length <= rope starting at or after i.
    l = cnt, r = cnt, sum = 0, now = 0;
    while (l >= 0) {
        if (sum + a[l] <= rope) {
            sum += a[l];
            now += b[l];
            l--;
        } else {
            sum -= a[r];
            now -= b[r];
            r--;
        }
        chkmax(suf[l], now + min(rope - sum, b[l] + b[r + 1]));
        chkmax(suf[l + 1], now + min(rope - sum, b[r + 1]));
    }
    for (int i = cnt; i >= 1; i--) chkmax(suf[i], suf[i + 1]);

    // Pass 3: window [l + 1, r] of length <= rope swept left to right,
    // combined with the best window to its right from pass 2.
    l = 0, r = 0, sum = 0, now = 0;
    while (r < cnt) {
        if (sum + a[r + 1] <= rope) {
            sum += a[r + 1];
            now += b[r + 1];
            r++;
        } else {
            sum -= a[l + 1];
            now -= b[l + 1];
            l++;
        }
        chkmax(ans, now + min(rope - sum, b[l] + b[r + 1]) + suf[r + 2]);
        chkmax(ans, now + min(rope - sum, b[l]) + suf[r + 1]);
    }
    return ans;
}

// Runs solve(int*) on len[1..cnt] in both directions and returns the better result.
int solve() {
    static int fwd[MAXN + 2], rev[MAXN + 2];  // index 0 is never written, so it stays 0
    for (int i = 1; i <= cnt; i++) {
        fwd[i] = len[i];
        rev[i] = len[cnt - i + 1];
    }
    // solve(int*) reads a[cnt + 1]; clear any value left by an earlier test case.
    fwd[cnt + 1] = rev[cnt + 1] = 0;
    return max(solve(fwd), solve(rev));
}

// Resets per-test state. Only indices 0..n+m are touched during a test, so
// only that prefix is cleared (instead of the full MAXN-sized tables).
void init() {
    const int nodes = n + m + 1;
    fill(d, d + nodes, 0);
    fill(dp, dp + nodes, 0);
    fill(val, val + nodes, 0);
    fill(vis, vis + nodes, false);
    fill(depth, depth + nodes, 0);
    memset(father, 0, sizeof(father[0]) * nodes);
    for (int i = 1; i <= n + m; i++) {
        a[i].clear();
        b[i].clear();
    }
}

// Marks every node reachable from pos in vis[]. Iterative, because the split
// graph can contain a chain of n + m nodes, which is too deep for recursion
// on a default-sized stack.
void visit(int pos) {
    vector<int> stk(1, pos);
    vis[pos] = true;
    while (!stk.empty()) {
        int u = stk.back();
        stk.pop_back();
        for (int v : a[u])
            if (!vis[v]) {
                vis[v] = true;
                stk.push_back(v);
            }
    }
}

// LCA of x and y in the dominator tree built so far, via father[][].
int lca(int x, int y) {
    if (depth[x] < depth[y]) swap(x, y);
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (depth[father[x][i]] >= depth[y]) x = father[x][i];
    if (x == y) return x;
    for (int i = MAXLOG - 1; i >= 0; i--)
        if (father[x][i] != father[y][i]) {
            x = father[x][i];
            y = father[y][i];
        }
    return father[x][0];
}

int main() {
    int T;
    read(T);
    while (T--) {
        read(n), read(m);
        read(s), read(t);
        read(rope);
        s++, t++;  // input vertices are 0-based; nodes here are 1-based
        init();

        // Split each edge x -> y into x -> (edge node) -> y.
        for (int i = 1; i <= m; i++) {
            int x, y, num = i + n;
            read(x), read(y), read(val[num]);
            x++, y++;
            a[x].push_back(num);
            b[num].push_back(x);
            a[num].push_back(y);
            b[y].push_back(num);
        }

        visit(s);
        tot = 0;
        if (!vis[t]) {
            printf("-1\n");
            continue;
        }

        // Kahn's topological sort restricted to nodes reachable from s.
        for (int i = 1; i <= n + m; i++) {
            if (!vis[i]) continue;
            for (int p : b[i])
                if (vis[p]) d[i]++;
            if (d[i] == 0) q[++tot] = i;
        }
        for (int i = 1; i <= tot; i++) {
            int u = q[i];
            for (int v : a[u])
                if (vis[v] && --d[v] == 0) q[++tot] = v;
        }

        // Dominator tree: in topological order, idom(u) is the LCA of u's
        // reachable predecessors. q[1] is s, the only reachable node with no
        // reachable predecessor.
        depth[s] = 1;
        for (int i = 2; i <= tot; i++) {
            int u = q[i], idom = -1;
            for (int p : b[u])
                if (vis[p]) idom = (idom == -1) ? p : lca(idom, p);
            depth[u] = depth[idom] + 1;
            father[u][0] = idom;
            for (int j = 1; j < MAXLOG; j++) father[u][j] = father[father[u][j - 1]][j - 1];
        }

        // Reverse topological sweep. dp[u] is the shortest distance from u
        // to the tail of the most recently recorded must-pass edge (or to t).
        // Nodes that cannot reach t keep INF, so dead ends never shorten a
        // gap. pos walks up the dominator tree from t; each edge node on that
        // path is a must-pass edge.
        int pos = t;
        cnt = 0;
        for (int i = tot; i >= 1; i--) {
            int u = q[i];
            int best = (u == t) ? 0 : INF;
            for (int v : a[u]) chkmin(best, dp[v]);
            if (u == pos && pos > n) {
                if (cnt > 0) len[++cnt] = best;  // gap to the next must-pass edge
                len[++cnt] = val[u];             // the must-pass edge itself
                dp[u] = 0;
            } else {
                dp[u] = (best == INF) ? INF : best + val[u];
            }
            if (u == pos) pos = father[pos][0];
        }

        int totlen = 0;
        for (int i = 1; i <= cnt; i += 2) totlen += len[i];
        printf("%d\n", totlen - solve());
    }
    return 0;
}