1. 程式人生 > >[BZOJ4753][Jsoi2016]最佳團體(分數規劃+樹形DP)

[BZOJ4753][Jsoi2016]最佳團體(分數規劃+樹形DP)

Address

洛谷P4322
BZOJ4753
LOJ#2071

Solution

看到最大化分式的值,考慮分數規劃,二分答案 m i d mid
判定答案是否能夠大於 m

i d mid 也就是判斷是否存在一個以 0 0 為根的大小為 K +
1 K+1
的連通子樹(假設 P 0 = S 0
= 0 P_0=S_0=0
)滿足:
P S > m i d \frac{\sum P}{\sum S}>mid
把分母去掉並移項:
P m i d × S > 0 \sum P-mid\times \sum S>0
於是我們給每一個點一個新的權值 v a l i = P i m i d × S i val_i=P_i-mid\times S_i
問題轉化成樹上以 0 0 為根的大小為 K + 1 K+1 的權值和最大的連通子樹。
顯然樹形dp
f [ u ] [ i ] f[u][i] 表示 u u 為根選出 i i 個點的最大權值和連通子樹。其中 f [ u ] [ 0 ] f[u][0] 表示什麼都不選(對於任意的 u u f [ u ] [ 0 ] = 0 f[u][0]=0
此外, f [ u ] [ 1 ] = v a l u f[u][1]=val_u
這是一個樹形依賴揹包問題。
設當前列舉到 u u 的子樹 v v ,並且 f [ u ] [ i ] f'[u][i] 表示 u u 的子樹內(不包括 v v v v 之後的子樹),現在要把 f [ v ] f[v] 合併起來,計算 f [ u ] [ ] f[u][] 表示依次到子樹 v v 的答案:
f [ u ] [ i + j ] = max ( f [ u ] [ i + j ] , f [ u ] [ i ] + f [ v ] [ j ] ) f[u][i+j]=\max(f[u][i+j],f[u][i]+f[v][j])
一次 dp 複雜度看上去是 O ( n 3 ) O(n^3) 的,但是第二維的上界只有 u u 的子樹大小,相當於每對點都只在 lca 處計算貢獻了一次,所以一次 dp 複雜度 O ( n 2 ) O(n^2)
總複雜度 O ( n 2 log ) O(n^2\log)

Code

#include <cmath>
#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>
#define For(i, a, b) for (i = a; i <= b; i++)
#define Tree(u) for (int e = adj[u], v = go[e]; e; e = nxt[e], v = go[e])
using namespace std;

inline int read()
{
	int res = 0; bool bo = 0; char c;
	while (((c = getchar()) < '0' || c > '9') && c != '-');
	if (c == '-') bo = 1; else res = c - 48;
	while ((c = getchar()) >= '0' && c <= '9')
		res = (res << 3) + (res << 1) + (c - 48);
	return bo ? ~res + 1 : res;
}

const int N = 2515;
const double eps = 1e-4;
int k, n, S[N], P[N], ecnt, nxt[N], adj[N], go[N], sze[N];
double f[N][N], val[N], x[N], y[N];

template <class T>
T Min(T a, T b) {return a < b ? a : b;}

template <class T>
T Max(T a, T b) {return a > b ? a : b;}

void add_edge(int u, int v)
{
	nxt[++ecnt] = adj[u]; adj[u] = ecnt; go[ecnt] = v;
}

void dfs(int u)
{
	int i, j;
	f[u][0] = 0; f[u][1] = val[u];
	For (i, 2, k + 1) f[u][i] = -1e20;
	sze[u] = 1;
	Tree(u)
	{
		dfs(v);
		int le = Min(k + 1, sze[u]), ri = Min(k + 1, sze[v]);
		For (i, 0, le) x[i] = f[u][i];
		For (i, 0, ri) y[i] = f[v][i];
		For (i, 1, le) For (j, 0, ri) if (i + j <= k + 1)
			f[u][i + j] = Max(f[u][i + j], x[i] + y[j]);
		sze[u] += sze[v];
	}
}

int main()
{
	int i, x;
	k = read(); n = read();
	For (i, 1, n) S[i] = read(), P[i] = read(),
		x = read(), add_edge(x, i);
	double l = 0, r = 1e4;
	while (r - l >= eps)
	{
		double mid = (l + r) / 2;
		For (i, 1, n) val[i] = 1.0 * P[i] - mid * S[i];
		if (dfs(0), f[0][k + 1] > 0) l = mid;
		else r = mid;
	}
	printf("%.3lf\n", l);
	return 0;
}