1. 程式人生 > 其它 >UVA1205 Color a Tree 題解

UVA1205 Color a Tree 題解

Post time: 2020-08-05 21:23:50

傳送門

題目大意大家可以開啟題目描述中PDF看,下面開始講題解。

一、思維嘗試:

首先我們思考,為了讓總的結果最小,整個樹中權值最大的點一定在他的父親節點染色之後馬上染色。所以我們首先考慮將這兩個點合併,繼續在整個圖中找最大權值……最後只剩根節點的時候合併結束,輸出結果即可。

這個思路看起來不錯,但是有一個問題:這個“最大權值”指的是什麼呢?如果僅僅指每一個點的初始染色代價,那麼只要構造一個像這樣的圖:(節點上的值表示每個點染色代價)

(插一句:這個資料是我們機房巨佬 @sdfz171047 造出來的,真的好巧(du)妙(liu))

你會發現我們合併順序是這樣的:

10->2
5->1
4->2
2->1

這樣的話,相當於的染色順序為:

1->5->2->10->4

它的代價為:

1*1+5*2+2*3+10*4+4*5=77

然而,如果你用下面這個順序染色:

1->2->10->5->4

代價就是:

1*1+2*2+3*10+5*4+4*5=75

這樣證明了以節點初始染色代價為貪心策略是不正確的。

二、如何推貪心?

我們想要知道貪心策略,必須要找到對於圖中任意三個點的正確順序。我們設這三個點為 \(x,y,z\),並且假定 \(x,y\) 已經捆綁了,必須先染 \(x\) 再染 \(y\)

。那麼會有兩種情況:

  1. \(x\to y\to z\) 代價是 \(x+2y+3z\)
  2. \(z\to x\to y\) 代價是 \(2x+3y+z\)

比較兩式可用作差法,相減可得 \(2z-(x+y)\),除以 \(2\)\(z-\frac{x+y}{2}\)

這樣我們就找到了貪心的權值,我們可以稱它為一個點的“等效權值”,我們用 \(cnt_i\) 表示以 \(i\) 為根的子樹大小,\(sum_i\) 表示以 \(i\) 為根的子樹染色代價之和,那麼每個點的等效權值 \(W_i\) 為:

\[W_i=\frac{sum_i}{cnt_i} \]

所以我們維護一個大根堆,每次取出當前最大的 \(W_i\)

,將這個點與 \(fa_i\) 合併。

三、程式碼怎麼實現?

我的做法是維護每個點後面染哪個點 \(nxt_i\),最後從根節點開始通過這個 \(nxt\) 遍歷整個樹完成答案統計。

在結構體裡,對於每個點存這樣幾個資料:

a[i].fa //父親節點編號
a[i].val //初始染色代價
a[i].last //以i為根合併之後的最後一個染色點編號
a[i].nxt //在i號點後面染色的點
a[i].vis //i有沒有向上合併過

考慮將 \(u\) 合併到 \(f\)(注意順序):

  1. 如果 \(f\) 已經向上合併過,那麼就繼續往上找,即 f=a[f].fa
  2. \(u\) 應該在 \(f\) 的最後一個點後染色,即 a[a[f].last].nxt=u
  3. \(f\) 的最後一個點改為 \(u\) 的最後一個點,即 a[f].last=a[u].last
  4. 更新 \(W_f\),將 \(f\) 入隊。把 \(u\)vis 設為 \(1\),表示已經合併完了。

最後掃一遍即可得到結果。

點選檢視程式碼
#include<iostream>
#include<cstdio>
#include<cstring>
#include<queue>
#include<vector>
using namespace std;
const int N=1000+13;
struct Node{int val,fa,last,pos,nxt,cnt,sum;}a[N];
struct Queue{
	int v;double w;
	bool operator <(const Queue &a)const{return w<a.w;}
};
priority_queue<Queue>q;
int n,root,ans;
bool vis[N];
inline void clear(){
	ans=0;
	memset(a,0,sizeof(a));
	memset(vis,0,sizeof(vis));
	while(!q.empty()) q.pop();
}
int main(){
	while(scanf("%d%d",&n,&root)==2&&(n||root)){
		clear();
		for(int i=1;i<=n;++i){
			scanf("%d",&a[i].val);
			a[i].last=a[i].pos=i;
			a[i].sum=a[i].val,a[i].cnt=1;
			if(i!=root) q.push((Queue){i,a[i].val*1.0});
		}
		for(int i=1,u,v;i<n;++i) scanf("%d%d",&u,&v),a[v].fa=u;
		while(!q.empty()){
			int v=q.top().v,u=a[v].fa;q.pop();
			if(vis[v]) continue;vis[v]=1;
			while(vis[u]&&u!=root) u=a[u].fa;
			a[a[u].last].nxt=v,a[u].last=a[v].last;
			a[u].cnt+=a[v].cnt,a[u].sum+=a[v].sum;
			double w=a[u].sum*1.0/a[u].cnt;
			if(u!=root) q.push((Queue){u,w});
		}
		for(int i=1,u=root;i<=n;++i,u=a[u].nxt) ans+=i*a[u].val;
		printf("%d\n",ans);
	}
	return 0;
}