1. 程式人生 > 其它 >【ybtoj高效進階 21188】樹上賦值(樹形DP)

【ybtoj高效進階 21188】樹上賦值(樹形DP)

給你一棵樹,你要給每個點賦上 1~m 範圍的一個值,然後保證任意一條邊連著的兩個點的權值差大於等於給出的 k,然後問你有多少種賦值方法。

樹上賦值

題目連結:ybtoj高效進階 21188

題目大意

給你一棵樹,你要給每個點賦上 1~m 範圍的一個值,然後保證任意一條邊連著的兩個點的權值差大於等於給出的 k,然後問你有多少種賦值方法。

思路

考慮先暴力 DP,設 \(f_{i,j}\) 為處理好 \(i\) 的子樹,然後 \(i\) 這個點選的 \(j\) 這個值的方案數。

然後就列舉每個兒子,\(f_{u,j}\) 乘上滿足的 \(l\)\(f_{v,l}\) 的和。(滿足的 \(l\)\(|j-l|\geqslant k\)

然後我們可以用字首和處理一個區間的 \(f_{x,i}\) 和,但是 \(O(nm)\)

還是會 T。
然後你考慮 \(m\) 那麼大,中間很多部分都是相同的值。

那我們會發現它最大有不同的就在 \((n-1)*k\) 個。(兩半各有那麼多)
而且兩邊的還是對稱的,所以我們就可以只維護兩邊不同的那 \((n-1)*k\) 個,複雜度就是 \(O(n*nk)=O(n^2k)\) 了。

程式碼

#include<cstdio>
#include<cstring>
#include<iostream>
#define ll long long
#define mo 1000000007

using namespace std;

struct node {
	int to, nxt;
}e[201];
int T, n, m, k, le[101], KK, x, y, lim;
ll f[101][20001], sum[101][20001];

void add(int x, int y) {
	e[++KK] = (node){y, le[x]}; le[x] = KK;
}

ll ksm(ll x, ll y) {
	ll re = 1;
	while (y) {
		if (y & 1) re = re * x % mo;
		x = x * x % mo;
		y >>= 1;
	}
	return re;
}

ll clac(int now, int r) {//算出字首
	if (r <= lim) return sum[now][r];
	if (r <= m - lim) {
		return (sum[now][lim] + f[now][lim] * (r - lim) % mo) % mo;
	}
	return (sum[now][lim] + f[now][lim] * (m - 2 * lim) % mo + sum[now][lim] - sum[now][m - r] + mo) % mo;
}

void dfs(int now, int father) {
	for (int i = 1; i <= lim; i++)
		f[now][i] = 1;
	for (int i = le[now]; i; i = e[i].nxt)
		if (e[i].to != father) {
			dfs(e[i].to, now);
			for (int j = 1; j <= lim; j++) {
				int L = max(0, j - k);
				ll val = sum[e[i].to][L];
				int R = min(m, j + k - 1);
				val = (val + clac(e[i].to, m) - clac(e[i].to, R) + mo) % mo; 
				f[now][j] = f[now][j] * val % mo;
			}
		}
	for (int i = 1; i <= lim; i++)
		sum[now][i] = (sum[now][i - 1] + f[now][i]) % mo;
}

int main() {
//	freopen("label.in", "r", stdin);
//	freopen("label.out", "w", stdout);
	
	scanf("%d", &T);
	while (T--) {
		scanf("%d %d %d", &n, &m, &k);
		for (int i = 1; i <= n; i++) le[i] = 0; KK = 0;
		for (int i = 1; i < n; i++) {
			scanf("%d %d", &x, &y); add(x, y); add(y, x);
		}
		
		if (!k) {
			printf("%lld\n", ksm(m, n));
			continue;
		}
		
		lim = (n - 1) * k;
		if (2 * lim >= m) lim = m;
		
		dfs(1, 0);
		
		printf("%lld\n", clac(1, m));
	}
	
	return 0;
}