1. 程式人生 > 實用技巧 >暴力寫掛[CTSC2018]

暴力寫掛[CTSC2018]

題目描述

給定兩棵樹 \(T\)\(T'\)

\[\max(\mathrm{depth}(x) + \mathrm{depth}(y) - ({\mathrm{depth}(\mathrm{LCA}(x,y))}+{\mathrm{depth'}(\mathrm{LCA'}(x,y))})) \]

注:帶[ \('\) ]的表示第二棵樹

題解

注意到題目給的這個

\[\mathrm{depth}(x) + \mathrm{depth}(y) - {\mathrm{depth}(\mathrm{LCA}(x,y))}-{\mathrm{depth'}(\mathrm{LCA'}(x,y))} \]

似乎不太好算

我們把前3項轉換一下 發現上面這個式子實際上等於

\[\dfrac{1}{2}(\mathrm{depth}(x) + \mathrm{depth}(y) + \mathrm{dis}(x,y) - 2 * {\mathrm{depth'}(\mathrm{LCA'}(x,y))}) \]

這樣一來,前三項可以通過邊分治處理出來,然後最後一項則需要在第二棵樹上來計算

具體地說,我們對第一棵樹進行邊分治,然後將當前分治邊左邊的點標為黑點,右邊標為白點

假設一個點\(x\)到分治邊的距離為\(\mathrm{d}(x)\),分治邊的長度是\(v\),那麼上面式子的前3項實際上就等於\(\mathrm{depth}(x) + \mathrm{depth}(y) + (\mathrm{d}(x) + \mathrm{d}(y) + v)\)

所以把每個點的點權\(\mathrm{val}(x)\)設為\(\mathrm{depth}(x) + \mathrm{d}(x)\),然後就可以去處理第二棵樹了

在第二棵樹中列舉每個點作為lca,那麼現在目標就是找到兩個顏色不同,且在兩個不同兒子子樹裡的點使得它們的\(\mathrm{val}\)之和最大

\(f[x][0]\)表示\(x\)子樹中最大的黑點權值,\(f[x][1]\)表示最大白點權值;然後就可以在第二棵樹上進行dp來得到最大值 具體dp轉移見程式碼

但是dp一次是\(O(n)\)的 所以我們還需要在dp之前對第二棵樹建虛樹 在虛樹上dp

這樣總時間複雜度就是\(O(n\log^2 n)\)

的 依然會被卡掉。。。

如果想要\(O(n\log n)\)可以加上尤拉序+ST表求LCA以及基數排序建虛樹來強行降低複雜度 這裡我只寫了個\(O(1)\)求LCA 吸氧後勉強卡過 基數排序什麼的表示不懂

程式碼難度非常非常大 寫到心態爆炸

程式碼

#include <bits/stdc++.h>
#define NN 370005
using namespace std;
typedef long long ll;

template<typename T>
inline void read(T &num) {
	T x = 0, f = 1; char ch = getchar();
	for (; ch > '9' || ch < '0'; ch = getchar()) if (ch == '-') f = -1;
	for (; ch <= '9' && ch >= '0'; ch = getchar()) x = (x << 3) + (x << 1) + (ch ^ '0');
	num = x * f;
}

int n, q[NN], tp[NN], tot;
ll ww[NN], ans = -0x3f3f3f3f3f3f3f3f;
namespace p2{
	int head[NN], dfn[NN], pre[NN<<1], to[NN<<1], sz = 1, tme;
	ll val[NN<<1];
	
	inline void addedge(int u, int v, int w) {
		pre[++sz] = head[u]; head[u] = sz; to[sz] = v; val[sz] = w;
		pre[++sz] = head[v]; head[v] = sz; to[sz] = u; val[sz] = w;
	}
	
	int d[NN], p[1000005][21], lg2[1000005];
	int stk[NN], top;
	ll dep[NN], f[NN][2];
	bool tag[NN];
	
	void dfs(int x, int fa) {
		p[++tme][0] = x;
		dfn[x] = tme;
		for (int i = head[x]; i; i = pre[i]) {
			int y = to[i];
			if (y == fa) continue;
			d[y] = d[x] + 1;
			dep[y] = dep[x] + val[i];
			dfs(y, x);
			p[++tme][0] = x;
		}
	}
	
	inline int LCA(int x, int y) {
		if (dfn[x] > dfn[y]) swap(x, y);
		int l = dfn[x], r = dfn[y], len = dfn[y] - dfn[x] + 1;
		if (d[p[l][lg2[len]]] < d[p[r-(1<<lg2[len])+1][lg2[len]]]) {
			return p[l][lg2[len]];
		} else return p[r-(1<<lg2[len])+1][lg2[len]];
	}
	
	void init() {
		dfs(1, 0);
		for (int i = 2; i <= tme; i++) lg2[i] = lg2[i>>1] + 1;
		for (int l = 1; (1 << l) <= tme; l++) {
			for (int i = 1; i <= tme; i++) {
				if (d[p[i][l-1]] < d[p[i+(1<<(l-1))][l-1]]) {
					p[i][l] = p[i][l-1];
				} else p[i][l] = p[i+(1<<(l-1))][l-1];
			} 
		}
		memset(head, 0, sizeof(head));
		sz = 1;
	}
	
	bool cmp(int x, int y) {
		return dfn[x] < dfn[y];
	}
	
	void buildtree() {
		sz = 1;
		sort(q + 1, q + tot + 1, cmp);
		for (int i = 1; i <= tot; i++) tag[q[i]] = 1;
		stk[top=1] = 1;
		for (int i = 1; i <= tot; i++) {
			if (q[i] == 1) continue;
			if (top == 1) {
				stk[++top] = q[i]; 
				continue;
			}
			int lca = LCA(stk[top], q[i]);
			while (top > 1 && dfn[stk[top-1]] >= dfn[lca]) {
				addedge(stk[top], stk[top-1], 0);
				top--;
			}
			if (lca != stk[top]) {
				addedge(stk[top], lca, 0);
				stk[top] = lca;
			}
			stk[++top] = q[i];
		}
		while (top > 1) {
			addedge(stk[top], stk[top-1], 0);
			top--;
		}
	}
	
	void dp(int x, int fa, ll len) {
		f[x][0] = f[x][1] = -0x3f3f3f3f3f3f3f3f;
		if (tag[x]) f[x][tp[x]] = ww[x];
		for (int i = head[x]; i; i = pre[i]) {
			int y = to[i];
			if (y == fa) continue;
			dp(y, x, len);
			ll now = max(f[x][0] + f[y][1], f[x][1] + f[y][0]);
			ans = max(ans, len + now - 2 * dep[x]);
			f[x][0] = max(f[x][0], f[y][0]);
			f[x][1] = max(f[x][1], f[y][1]);
		} 
		tag[x] = 0; head[x] = 0;
	}
	
	void solve(ll len) {
		buildtree();
		dp(1, 0, len);
	}
}

namespace p1{
	int head[NN<<2], pre[NN<<3], to[NN<<3], sz = 1, N;
	ll val[NN<<3];
	vector<pair<int, ll> > son[NN<<2];
	bool vis[NN<<2];
	int siz[NN<<2], ct, mn, sum;
	ll dep[NN<<2];
	
	inline void addedge(int u, int v, ll w) {
		pre[++sz] = head[u]; head[u] = sz; to[sz] = v; val[sz] = w;
		pre[++sz] = head[v]; head[v] = sz; to[sz] = u; val[sz] = w;
	}
	
	void dfs1(int x, int fa) {
		for (int i = head[x]; i; i = pre[i]) {
			int y = to[i];
			if (y == fa) continue;
			son[x].push_back(make_pair(y, val[i]));
			dep[y] = dep[x] + val[i];
			dfs1(y, x);
		}
	}
	
	void rebuild() {
		memset(head, 0, sizeof(head)); sz = 1;
		for (int i = 1; i <= N; i++) {
			int k = son[i].size();
			if (k <= 2) {
				for (int j = 0; j < k; j++) {
					addedge(i, son[i][j].first, son[i][j].second);
				}
			} else {
				addedge(i, ++N, 0); addedge(i, ++N, 0);
				for (int j = 0; j < k; j++) {
					if (j & 1) son[N-1].push_back(son[i][j]);
					else son[N].push_back(son[i][j]);
				}
			}
		}
	}
	
	void findct(int x, int fa) {
		siz[x] = 1;
		for (int i = head[x]; i; i = pre[i]) {
			int y = to[i];
			if (y == fa || vis[i>>1]) continue;
			findct(y, x);
			siz[x] += siz[y];
			int now = max(siz[y], sum - siz[y]);
			if (now < mn) {
				mn = now;
				ct = i;
			}
		}
	}
	
	void dfs(int x, int fa, ll dis, int o) {
		if (x <= n) {
			q[++tot] = x;
			ww[x] = dep[x] + dis;
			tp[x] = o;
		}
		for (int i = head[x]; i; i = pre[i]) {
			int y = to[i];
			if (y == fa || vis[i>>1]) continue;
			dfs(y, x, dis + val[i], o);
		}
	}
	
	void divide(int x, int _siz) {
		ct = 0; mn = 0x7fffffff; 
		sum = _siz;
		findct(x, 0);
		if (!ct) return;
		vis[ct>>1] = 1;
		int l = to[ct], r = to[ct^1];
		tot = 0;
		dfs(l, 0, 0, 0); dfs(r, 0, 0, 1);
		if (!tot) return;
		p2::solve(val[ct]);
		divide(l, siz[to[ct]]); divide(r, _siz - siz[to[ct]]);
	}
}


int main() {
	read(n); 
	p1::N = n;
	for (int i = 1, u, v, w; i < n; i++) {
		read(u); read(v); read(w);
		p1::addedge(u, v, w);
	}	
	for (int i = 1, u, v, w; i < n; i++) {
		read(u); read(v); read(w);
		p2::addedge(u, v, w);
	}
	p1::dfs1(1, 0);
	p1::rebuild();
	p2::init();
	p1::divide(1, p1::N);
	ans >>= 1;
	for (int i = 1; i <= n; i++) {
		ans = max(ans, p1::dep[i] - p2::dep[i]);
	}
	printf("%lld\n", ans);
	return 0;
}