【ybtoj高效進階 21188】樹上賦值(樹形DP)
阿新 • • 發佈:2021-11-19
給你一棵樹,你要給每個點賦上 1~m 範圍的一個值,然後保證任意一條邊連著的兩個點的權值差大於等於給出的 k,然後問你有多少種賦值方法。
還是會 T。
然後你考慮 \(m\) 那麼大,中間很多部分都是相同的值。
樹上賦值
題目連結:ybtoj高效進階 21188
題目大意
給你一棵樹,你要給每個點賦上 1~m 範圍的一個值,然後保證任意一條邊連著的兩個點的權值差大於等於給出的 k,然後問你有多少種賦值方法。
思路
考慮先暴力 DP,設 \(f_{i,j}\) 為處理好 \(i\) 的子樹,然後 \(i\) 這個點選的 \(j\) 這個值的方案數。
然後就列舉每個兒子,\(f_{u,j}\) 乘上滿足的 \(l\) 的 \(f_{v,l}\) 的和。(滿足的 \(l\) 是 \(|j-l|\geqslant k\))
然後我們可以用字首和處理一個區間的 \(f_{x,i}\) 和,但是 \(O(nm)\)
然後你考慮 \(m\) 那麼大,中間很多部分都是相同的值。
那我們會發現它最大有不同的就在 \((n-1)*k\) 個。(兩半各有那麼多)
而且兩邊的還是對稱的,所以我們就可以只維護兩邊不同的那 \((n-1)*k\) 個,複雜度就是 \(O(n*nk)=O(n^2k)\) 了。
程式碼
#include<cstdio> #include<cstring> #include<iostream> #define ll long long #define mo 1000000007 using namespace std; struct node { int to, nxt; }e[201]; int T, n, m, k, le[101], KK, x, y, lim; ll f[101][20001], sum[101][20001]; void add(int x, int y) { e[++KK] = (node){y, le[x]}; le[x] = KK; } ll ksm(ll x, ll y) { ll re = 1; while (y) { if (y & 1) re = re * x % mo; x = x * x % mo; y >>= 1; } return re; } ll clac(int now, int r) {//算出字首 if (r <= lim) return sum[now][r]; if (r <= m - lim) { return (sum[now][lim] + f[now][lim] * (r - lim) % mo) % mo; } return (sum[now][lim] + f[now][lim] * (m - 2 * lim) % mo + sum[now][lim] - sum[now][m - r] + mo) % mo; } void dfs(int now, int father) { for (int i = 1; i <= lim; i++) f[now][i] = 1; for (int i = le[now]; i; i = e[i].nxt) if (e[i].to != father) { dfs(e[i].to, now); for (int j = 1; j <= lim; j++) { int L = max(0, j - k); ll val = sum[e[i].to][L]; int R = min(m, j + k - 1); val = (val + clac(e[i].to, m) - clac(e[i].to, R) + mo) % mo; f[now][j] = f[now][j] * val % mo; } } for (int i = 1; i <= lim; i++) sum[now][i] = (sum[now][i - 1] + f[now][i]) % mo; } int main() { // freopen("label.in", "r", stdin); // freopen("label.out", "w", stdout); scanf("%d", &T); while (T--) { scanf("%d %d %d", &n, &m, &k); for (int i = 1; i <= n; i++) le[i] = 0; KK = 0; for (int i = 1; i < n; i++) { scanf("%d %d", &x, &y); add(x, y); add(y, x); } if (!k) { printf("%lld\n", ksm(m, n)); continue; } lim = (n - 1) * k; if (2 * lim >= m) lim = m; dfs(1, 0); printf("%lld\n", clac(1, m)); } return 0; }