1. 程式人生 > 其它 >專題3 - 樹狀dp

專題3 - 樹狀dp

樹狀dp,顧名思義,給出一棵樹並對其賦予權值(或邊權或點權),在樹上進行動態規劃。

其基本思路是從根節點開始進行記憶化搜尋,或是從葉子節點開始自下而上正向推進。

NC22598

題目中要求所有度為1的點都不能到達關鍵點\(S\),那麼問題可以轉述成“對於一棵以\(S\)點位根節點的樹,所有的葉子節點都不能到達根\(S\),這邊給出一個棵樹:

在這張圖中,假設關鍵點為\(1\),那麼我需要保證\(2,4,5\)這三個點無法到達\(1\),對於\(4,5\)兩個點而言,我可以選擇刪去\(1->3\)這條邊,也可以分別刪去\(3->4\)和\(3->5\)兩條邊。

由此可以得出狀態轉移方程\(f[i]=max\{f[j]\}\),如果為葉子節點,返回該節點與其父節點的邊權。

有關存圖與搜尋的方式可以分為以下兩種。

(1) 如果父子關係明確,可以採用有向圖的方式存鄰接表。

(2) 對於任意樹,都可以採用無向圖的方式存鄰接表,在搜尋時記錄該節點的父節點防止走回頭路。

#include <bits/stdc++.h>
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define ll long long
#define pb push_back
using namespace std;
const int maxn = 1e5 + 10;
const ll INF = 1e18;
int
n, m, s; vector<int> v[maxn]; map<pair<int, int>, ll> mp; ll f[maxn]; ll dfs(int x, int fa) { if (v[x].size() == 1 && x != s) return mp[make_pair(x, fa)]; for (int i = 0; i < v[x].size(); i++) { if (v[x][i] == fa) continue; f[x]
+= dfs(v[x][i], x); } if (fa == 0) return f[x]; return f[x] = min(f[x], mp[make_pair(x, fa)]); } int main() { fast; cin >> n >> m >> s; int uu, vv, w; for (int i = 1; i <= m; i++) { cin >> uu >> vv >> w; v[uu].pb(vv); v[vv].pb(uu); mp[make_pair(uu, vv)] = w; mp[make_pair(vv, uu)] = w; } for (int i = 1; i <= n; i++) { f[i] = 0; } cout << dfs(s, 0) << '\n'; }

NC202475

題目相當的直接,找出權值最大的子鏈。

最關鍵的地方在於子鏈的定義。

仍然是這一棵樹,\(1->3->5\)固然是一條子鏈,然而如同\(4->3->5\),\(2->1->3->5\)這類也算作子鏈。在狀態轉移時,我們轉移以該節點為根的情況下子鏈權值的最大值,對於每個節點,找到權值最大的兩條子鏈並將權值相加。

#include <bits/stdc++.h>
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define ll long long
#define pb push_back
using namespace std;
const int maxn = 1e5 + 10;
int n;
ll a[maxn];
ll dp[maxn] = {0};
ll maxx = -1e9;
vector<int> v[maxn];
void dfs(int x, int pre)
{
    dp[x] = a[x];
    for (int i = 0; i < v[x].size(); i++)
    {
        if (v[x][i] == pre)
            continue;
        dfs(v[x][i], x);
        maxx = max(maxx, dp[x] + dp[v[x][i]]);
        dp[x] = max(dp[x], dp[v[x][i]] + a[x]);
    }
    maxx = max(maxx, dp[x]);
}
int main()
{
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> a[i];
    }
    for (int i = 1; i < n; i++)
    {
        int p, q;
        cin >> p >> q;
        v[p].pb(q);
        v[q].pb(p);
    }
    dfs(1, 0);
    cout << maxx << '\n';
}

NC15033

再把這個圖複製下來:

對於節點\(3\),它的子樹包括\(\{4\}\),\(\{5\}\),\(\{1,2\}\),我們可以從任意點作為根節點進行搜尋,在這棵根確定的樹中,對於任意一點\(x\),不僅需要考慮子樹,還需要向上考慮剩餘的節點數量(因為當\(x\)作為根節點是,這些剩下的點也將成為一棵子樹)。

在狀態轉移時,用\(f[x]\)表示以該節點為根的子樹的節點數,那麼(\n-f[x]\)就是剩餘部分的節點數,找到其中的最小值即可(節點數越小,越平衡)。

#include <bits/stdc++.h>
#define ll long long
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
#define pb push_back
using namespace std;
const int maxn = 1010;
vector<int> v[maxn];
int dp[maxn];
int ansx = 1e9, ansnum = 1e9;
int n;
void dfs(int x, int pre)
{
    dp[x] = 1;
    int m = 0;
    for (int i = 0; i < v[x].size(); i++)
    {
        if (v[x][i] == pre)
            continue;
        dfs(v[x][i], x);
        dp[x] += dp[v[x][i]];
        m = max(m, dp[v[x][i]]);
    }
    m = max(m, n - dp[x]);
    if (m < ansnum)
    {
        ansnum = m;
        ansx = x;
    }
    else if (m == ansnum)
    {
        if (ansx > x)
            ansx = x;
    }
}
int main()
{
    fast;
    cin >> n;
    int p, q;
    for (int i = 1; i < n; i++)
    {
        cin >> p >> q;
        v[p].pb(q);
        v[q].pb(p);
    }
    dfs(1, 0);
    cout << ansx << ' ' << ansnum << '\n';
}

NC51178

狀態轉移思路差別不大,只是對於一個節點,需要和他孩子的孩子建立關係。因為上面的所有題目都是父子關係,所以不需要進行記憶化搜尋,但這裡對於每個點,都有可能被搜尋到兩次,所以要用到記憶化搜尋。

#include <bits/stdc++.h>
#define ll long long
#define pb push_back
#define fast ios::sync_with_stdio(0), cin.tie(0), cout.tie(0)
using namespace std;
const int maxn = 6010;
ll h[maxn];
int vis[maxn];
ll f[maxn];
vector<int> mp[maxn];
ll dfs(int x)
{
    if (f[x])
        return f[x];
    ll tmp1 = 0;
    ll tmp2 = 0;
    for (int i = 0; i < mp[x].size(); i++)
    {
        int nex = mp[x][i];
        tmp2 += dfs(mp[x][i]);
        for (int j = 0; j < mp[nex].size(); j++)
        {
            tmp1 += dfs(mp[nex][j]);
        }
    }
    return f[x] = max(tmp1 + h[x], tmp2);
}
int main()
{
    fast;
    int n;
    cin >> n;
    for (int i = 1; i <= n; i++)
    {
        cin >> h[i];
    }
    int u, v;
    for (int i = 1; i < n; i++)
    {
        cin >> u >> v;
        vis[u]++;
        mp[v].pb(u);
    }
    cin >> u >> v;
    for (int i = 1; i <= n; i++)
    {
        f[i] = 0;
    }
    for (int i = 1; i <= n; i++)
    {
        if (!vis[i])
            dfs(i), cout << f[i] << '\n';
    }
}