1. 程式人生 > 實用技巧 >淺談樹形DP

淺談樹形DP

淺談樹形DP

本篇隨筆簡單講解一下DP中的樹形DP問題。


樹形DP的概念

樹形DP在本蒟蒻心目中的印象就是轉移過程中,某個節點維護的資訊是由其子節點給出的。換句話說,樹形DP就是在樹上跑DP,將需要維護的答案資訊一層一層地傳遞到根節點,然後得出整個問題的答案。

實際上,我更喜歡把圖理解為一些元素加一些關係。同樣地,樹也是一些關係和一些元素的結合。所以樹形DP只是把轉移過程中的資訊來源加了個限定條件:這個條件就是由樹的邊,也就是父子關係來限定的。所以,樹形DP其實與普通DP的思想是大同小異的。

樹形DP的實現

根據我們剛剛給出的定義,樹形DP中每個節點維護的資訊都是由子節點給出的。也就是說,我們需要從各個葉子開始向根節點層層更新,也就是一種回溯

。那麼我們自然而然地聯想到使用深搜解決這個問題。

具體的實現,我們用一道例題:洛谷P1352 沒有上司的舞會 來講解


題目連結:題目傳送門

題解連結:題解傳送門

題目描述

某大學有 nn 個職員,編號為 1\ldots n1…n

他們之間有從屬關係,也就是說他們的關係就像一棵以校長為根的樹,父結點就是子結點的直接上司。

現在有個週年慶宴會,宴會每邀請來一個職員都會增加一定的快樂指數 r_ir**i,但是呢,如果某個職員的直接上司來參加舞會了,那麼這個職員就無論如何也不肯來參加舞會了。

所以,請你程式設計計算,邀請哪些職員可以使快樂指數最大,求最大的快樂指數。

輸入格式

輸入的第一行是一個整數 nn

第 22 到第 (n + 1)(n+1) 行,每行一個整數,第 (i+1)(i+1) 行的整數表示 ii 號職員的快樂指數 r_ir**i

第 (n + 2)(n+2) 到第 (2n + 1)(2n+1) 行,每行輸入一對整數 l, kl,k,代表 kk 是 ll 的直接上司。

輸出格式

輸出一行一個整數代表最大的快樂指數。


看到這個題是個多階段決策的問題,然後還有樹形結構,那麼就是樹形DP。(逃

那麼考慮狀態和轉移。

第一維肯定是以\(i\)為根的子樹。我們容易發現這個狀態肯定與當前節點有沒有選擇有關。因為選了當前節點,他的兒子們就都選不了。所以第二維就設定成選不選當前節點。

那麼我們的狀態就是:\(dp[i][0/1]\)

表示以\(i\)為根的子樹不邀請/邀請\(i\)得到的最大快樂指數。

那麼狀態轉移方程就是:

\[dp[x][0]+=\max(dp[y][0],dp[y][1])\qquad(y\in son[x]) \\ \quad \\dp[x][1]+=dp[y][0] \]

方程很好想。也很容易理解。

那麼就是轉移。也就是這道例題著重講解的地方。

其實我們說樹形DP就是從葉子往根節點轉移,換句話說,就是從下到上統計資訊。那麼它就與我們做過的其他樹上統計資訊的一樣。比如樹的重心的找法。比如樹鏈剖分的預處理部分,等等。

那麼就回歸到了深搜上面,只需要把統計的資訊換成\(dp[i][j]\)陣列即可。

程式碼:

#include<cstdio>
#include<algorithm>
using namespace std;
const int maxn=6010;
int n,root,ans;
int a[maxn],fa[maxn];
int tot,to[maxn<<1],nxt[maxn<<1],head[maxn];
int dp[maxn][2];//dp[i][0/1]表示以i為根的子樹邀請/不邀請i的最大快樂指數。
void add(int x,int y)
{
    to[++tot]=y;
    nxt[tot]=head[x];
    head[x]=tot;
}
void dfs(int x,int f)
{
    for(int i=head[x];i;i=nxt[i])
    {
        int y=to[i];
        if(y==fa[x])
            continue;
        dfs(y,x);
        dp[x][0]+=max(dp[y][0],dp[y][1]);
        dp[x][1]+=dp[y][0];
    }
}
int main()
{
    scanf("%d",&n);
    for(int i=1;i<=n;i++)
        scanf("%d",&a[i]);
    for(int i=1;i<n;i++)
    {
        int x,y;
        scanf("%d%d",&x,&y);
        add(x,y);
        add(y,x);
        fa[x]=y;
    }
    for(int i=1;i<=n;i++)
    {
        if(!fa[i])
            root=i;
        dp[i][1]=a[i];
    }
    dfs(root,0);
    ans=max(dp[root][0],dp[root][1]);
    printf("%d",ans);
    return 0;
}