1. 程式人生 > 實用技巧 >啟智樹衝省隊組Day5T3 劃分

啟智樹衝省隊組Day5T3 劃分

首先可以把題意轉化一下:

把一棵點帶權樹上的 \(n\) 個點劃分為 \(A\)\(B\) 兩個集合,一種劃分方案的代價是:來自 \(A\) 集合的代價:\(\sum_{i}d_i+g(A)-\sum_{x \gets y}[w_x \le w_y]\);來自 \(B\) 集合的代價:\(\sum_{x \gets y}[w_x < w_y]\)。其中 \(x \gets y\) 表示 \(x\)\(y\) 的祖先, \(g(A)={|A| \choose 2}\)\(d_i\) 表示 \(i\) 到根的距離。總代價為兩集合的代價之和。

全部在 \(B\) 集合的情況非常好算,考慮從 \(B\)

中逐個取點加入到 \(A\)

如果點權互不相同,則容易發現取每個點的代價為 \(d_i - a_i - b_i\),其中 \(a_i\) 表示比 \(i\) 小的祖先個數, \(b_i\) 表示比 \(i\) 大的子孫個數。它們與 \(A,B\) 集合的狀態無關,所以從小到大排序後一個個加入即可。

如果有點權相同的情況,則代價應該是 \(d_i - a_i - b_i - c_i\),其中 \(c_i\) 表示目前 \(A\) 集合中與 \(i\) 有祖先子孫關係的點中與 \(i\) 權值相同的點的個數。這時候我們不知道 \(c_i\) 是多少,並且不知道取走 \(d_i - a_i - b_i - c_i\)

最小的點的後續轉移是否更優。

不過看起來如果存在祖先子孫權值相同,先取祖先好像更優。因為 \(d_x - a_x \le d_y - a_y\)(儘管可能 \(a_x \le a_y\),但是 \(d_x > d_y\) 有壓倒性優勢),並且一定有 \(-b_x - c_x \le -b_y - c_y\),於是當前祖先一定比子孫更優。並且選擇祖先的後續轉移也比子孫的好,因為選擇祖先能“惠及”更多的點有 \(-c_i\) 的代價。因此,選擇一個點的時候與其相等祖先一定全部被選,與其相等的子孫一定全沒被選。據此,\(c_i\) 還可以表示為與 \(i\) 相等的祖先的個數。

每個點的 \(a_i,b_i,c_i,d_i\)

可以用樹狀陣列算出。

關鍵程式碼:

int fa[N], dep[N];
int a[N], b[N];
void dfs1(int cur, int faa) {
	fa[cur] = faa; dep[cur] = dep[faa] + 1;
	a[cur] = query(w[cur]);
	add(w[cur], 1);
	for (register int i = head[cur]; i; i = e[i].nxt) {
		int to = e[i].to; if (to == faa)	continue;
		dfs1(to, cur);
	}
	add(w[cur], -1);
}
void dfs2(int cur, int faa) {
	add(w[cur], 1);
	b[cur] -= (query(ltot) - query(w[cur]));
	for (register int i = head[cur]; i; i = e[i].nxt) {
		int to = e[i].to; if (to == faa)	continue;
		dfs2(to, cur);
	}
	b[cur] += (query(ltot) - query(w[cur]));
}

int id[N];
inline bool cmp(const int x, const int y) {
	return dep[x] - a[x] - b[x] < dep[y] - a[y] - b[y];
}

ll ans[N];
int main() {
	read(n);
	for (register int i =1 ; i <= n; ++i)	read(w[i]), h[i] = w[i];
	lsh();
	for (register int i =1 ; i < n; ++i) {
		int u, v; read(u), read(v);
		addedge(u, v), addedge(v, u);
	}
	dep[0] = -1;
	dfs1(1, 0);
	dfs2(1, 0);
	for (register int i = 1; i <= n; ++i)	id[i] = i;
	sort(id +1 , id + 1 + n, cmp);
	ll res = 0;
	for (register int i = 1; i <= n; ++i)	res += b[i];
	ans[n] = res;
	for (register int i = 1; i <= n; ++i) {
		int p = id[i];
		res += dep[p] - a[p] - b[p];
		ans[n - i] = res + ((1ll * i * (i - 1)) >> 1);
	}
	for (register int i = 0; i <= n; ++i)
		printf("%lld\n", ans[i]);
	return 0;
}