1. 程式人生 > 其它 >Codeforces Round #774 (Div. 2)D. Weight the Tree

Codeforces Round #774 (Div. 2)D. Weight the Tree

題目大意

一顆 \(n(2\leq n\leq 2\times 10^5)\) 的樹,需要為每個點賦予一個權值 \(w_{i}(1\leq w_{i}\leq10^9)\) 。一個節點稱為好節點當且僅當其相鄰的所有節點的權值和等於該節點的權值,給出一種賦值方案,使得樹中好節點的數目最多,並且所有節點的總權值和最小。

思路

考慮樹形 \(dp\) ,我們設 \(f[i,1/0]\) 為在以 \(i\) 為根的子樹中,是/否選擇了節點 \(i\) 作為好節點的情況下,最多的好節點數目以及滿足該數目所需的最小權值。顯然對於每一個好節點,令其權值為其度數,非好節點的權值為 \(1\) 顯然會是最佳方案。並且除了僅有兩個節點的情形(可以直接特判),所有的好節點都一定不相鄰,於是我們有:

\[f[i,1]=\sum_{j\in son[i]}f[j,0]+(1,deg(i)) \]\[f[i,0]=\sum_{j\in son[i]}max(f[j,0],f[j,1])+(0,1) \]

這裡的 \(max\) 我們定義為先比較好節點數,再比較權值總和,求和即為好節點數,權值和對應相加。我們在求出答案後,可以通過一遍 \(dfs\) 再根據根據轉移的情況來逆推(對於等價的轉移情況任選一種即可)求出方案,複雜度 \(O(n)\)

程式碼

#include<bits/stdc++.h>
#include<unordered_map>
#include<unordered_set>
using namespace std;
typedef long long LL;
typedef unsigned long long ULL;
typedef pair<int, int> PII;
#define all(x) x.begin(),x.end()
//#define int LL
//#define lc p*2+1
//#define rc p*2+2
#define endl '\n'
#define inf 0x3f3f3f3f
#define INF 0x3f3f3f3f3f3f3f3f
#pragma warning(disable : 4996)
#define IOS ios::sync_with_stdio(0),cin.tie(0),cout.tie(0)
const double eps = 1e-8;
const LL MOD = 1000000007;
const LL mod = 998244353;
const int maxn = 200010;

struct node {
	LL a, b;
	void operator+=(const node& rhs) { *this = node{ a + rhs.a,b + rhs.b }; }
	bool operator<(const node& rhs)
	{
		if (a == rhs.a)
			return b > rhs.b;
		return a < rhs.a;
	}
};
LL N;
vector<LL>G[maxn];
node dp[maxn][2];
LL ans[maxn];

void add_edge(LL from, LL to)
{
	G[from].push_back(to);
	G[to].push_back(from);
}

void DP(LL v, LL p)
{
	dp[v][1] = node{ 1,(LL)G[v].size() }, dp[v][0] = node{ 0,1 };
	for (int i = 0; i < G[v].size(); i++)
	{
		LL to = G[v][i];
		if (to == p)
			continue;
		DP(to, v);
		if (dp[to][1] < dp[to][0])
			dp[v][0] += dp[to][0];
		else
			dp[v][0] += dp[to][1];
		dp[v][1] += dp[to][0];
	}
}

void dfs(int v, int p, int t)
{
	ans[v] = (t ? (LL)G[v].size() : 1);
	for (int i = 0; i < G[v].size(); i++)
	{
		int to = G[v][i];
		if (to == p)
			continue;
		if (t == 1)
			dfs(to, v, 0);
		else
		{
			if (dp[to][0] < dp[to][1])
				dfs(to, v, 1);
			else
				dfs(to, v, 0);
		}
	}
}

void solve()
{
	if (N == 2)
	{
		cout << 2 << ' ' << 2 << endl;
		cout << 1 << ' ' << 1 << endl;
		return;
	}
	DP(1, 0);
	if (dp[1][0] < dp[1][1])
	{
		cout << dp[1][1].a << ' ' << dp[1][1].b << endl;
		dfs(1, 0, 1);
	}
	else
	{
		cout << dp[1][0].a << ' ' << dp[1][0].b << endl;
		dfs(1, 0, 0);
	}
	for (int i = 1; i <= N; i++)
		cout << ans[i] << ' ';
	cout << endl;
}

int main()
{
	IOS;
	cin >> N;
	int u, v;
	for (int i = 1; i < N; i++)
	{
		cin >> u >> v;
		add_edge(u, v);
	}
	solve();

	return 0;
}