P4362 [NOI2002]貪吃的九頭龍 題解
Post time: 2020-07-22 17:49:15
這個題顯然是一個樹形dp。我們先來總結一下樹形dp的套路:
-
定義 dp 陣列意義,注意要考慮到題目的一些特殊要求(比如本題的大頭),還要考慮到如何輸出結果。
-
思考如何將一個點的所有 dp 值由這個點的兒子節點轉移過來,即我們常說的狀態轉移方程。
-
將方程放到 dfs 深搜中更新 dp 值,最終輸出答案。
所以,我們先來考慮這個題怎麼樣定義 dp 陣列來處理特殊要求。
一、 關於大頭?
我們整理題目可以發現,關於大頭,本題大概有這樣兩個限制:
-
大頭必須吃掉 \(1\) 號節點
-
大頭必須吃掉 \(k\) 個節點
所以,我們大可以用 f[i][j]
表示對於 \(i\) 號節點,它的子樹一共有 \(j\) 個節點被大頭吃掉了。
這樣有什麼好處呢?我們可以發現,最後我們只需要輸出 f[1][k]
就萬事大吉了——等等,第一個限制是不是還沒考慮?
這樣,我們可以再把 f
開一維 [0/1]
,用 f[i][j][0/1]
表示 i
號節點有沒有被大頭吃掉,\(0\) 表示不是大頭吃的,\(1\) 表示是大頭吃的。現在我們可以輸出 f[1][k][1]
,就完完全全考慮完了大頭的限制啦!
二、轉移方程?
對於每一個節點 \(u\) 來說,我們需要把 f[u][0-k][0/1]
全部更新才算更新完了所有狀態。考慮什麼情況下會增加難受值——如果有 \(3\)
對於大頭不吃 \(u\) 點的情況,可能由 \(v,u\) 大頭都不吃和吃 \(v\) 不吃 \(u\) 兩種情況轉移過來;吃 \(u\) 點同理。dp 方程大概長這樣,反正就是考慮一下各種可能就好啦。
\[\begin{aligned} f_{u,j,0}&=min(f_{u,j,0},min(f_{v,t,0}+f_{u,j-t,0}+[m==2]* w,f_{v,t,1}+f_{u,j-t,0}))\\ f_{u,j,1}&=min(f_{u,j,1},min(f_{v,t,1}+f_{u,j-t,1}+w,f_{v,t,0}+f_{u,j-t,1}))\\ \end{aligned} \](\(u,v,w\)
這樣我們就可以把它放在 dfs 裡得出答案啦!
for(int j=0;j<=k;++j){
for(int t=0;t<=j;++t){
f[u][j][0]=min(f[u][j][0],min(f[v][t][0]+f[u][j-t][0]+(m==2)*w,f[v][t][1]+f[u][j-t][0]));
f[u][j][1]=min(f[u][j][1],min(f[v][t][1]+f[u][j-t][1]+w,f[v][t][0]+f[u][j-t][1]));
}
}
然而它掛了……我們思考一下為什麼呢?原來,這題在更新 f[u][j]
的時候會被 f[u][j-t]
更新,所以這個東西就開始自己搞自己了……所以我們要用一個數組先去記錄一下 f[u]
,然後就可以放心做 dp 了。
完了嗎?還是沒有……我就是卡在了 \(-1\) 上。我們思考一下,什麼情況下會無解呢?注意到題目中的一個條件——每個組都要有果子,也就是每個頭都要吃到果子。如果大頭吃剩下的比剩下的頭數要少,是不是就無解了?所以我們只需要判斷 \(n-k<m-1\) 就可以判掉無解情況了。
點選檢視程式碼
#include<iostream>
#include<cstdio>
#include<cstring>
using namespace std;
const int N=1000+4;
struct Edge{int v,w,nxt;}e[N<<1];
int h[N],f[N][N][2],tmp[N][2];
int tot,m,n,k;
inline void add(int u,int v,int w){
e[++tot]=(Edge){v,w,h[u]};
h[u]=tot;
}
void dfs(int u,int fa){
f[u][0][0]=f[u][1][1]=0;
for(int i=h[u];i;i=e[i].nxt){
int v=e[i].v,w=e[i].w;
if(v==fa) continue;
dfs(v,u);
memcpy(tmp,f[u],sizeof(f[u]));
memset(f[u],0x3f,sizeof(f[u]));
for(int j=0;j<=k;++j){
for(int t=0;t<=j;++t){
f[u][j][0]=min(f[u][j][0],min(f[v][t][0]+tmp[j-t][0]+(m==2)*w,f[v][t][1]+tmp[j-t][0]));
f[u][j][1]=min(f[u][j][1],min(f[v][t][1]+tmp[j-t][1]+w,f[v][t][0]+tmp[j-t][1]));
}
}
}
}
int main(){
memset(f,0x3f,sizeof(f));
scanf("%d%d%d",&n,&m,&k);
for(int i=1,u,v,w;i<n;++i){
scanf("%d%d%d",&u,&v,&w);
add(u,v,w),add(v,u,w);
}
if(n-k<m-1){printf("-1\n");return 0;}
dfs(1,0);
printf("%d\n",f[1][k][1]);
return 0;
}