1. 程式人生 > >poj1155樹形dp+揹包

poj1155樹形dp+揹包

題目大意:每個使用者必須連一個發射器,發射的訊號從一個點到另一個點需要費用,問電視臺在不虧本的情況下最多可以給多少個使用者發射訊號。

思路:dp[i][j]表示節點i發射j個訊號給使用者的盈利。注意:dp陣列初始化時,dp[i][0] = 0;其餘的負無窮,否則會出bug。自己跟著程式走一遍思路會清晰很多。

#include <iostream>
#include <cstdio>
#include <string>
#include <cstring>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <queue>
#include <stack>
#include <vector>
#include <map>
#include <set>
#include <iomanip>

using namespace std;
//#pragma comment(linker, "/STACK:102400000,102400000")
#define maxn 3005
#define MOD 1000000007
#define mem(a , b) memset(a , b , sizeof(a))
#define LL long long
#define INF 100000000
int n , id , m;
int res[maxn] , tmp[maxn];
struct edge
{
    int u , v , w;
    int next;
}E[maxn];
int head[maxn];
int dp[maxn][maxn];
int val[maxn];

void add(int u , int v , int w)
{
    E[id].u = u;
    E[id].v = v;
    E[id].w = w;
    E[id].next = head[u];
    head[u] = id++;
}

void dfs(int root)
{
    for(int i = head[root] ; i != -1 ; i = E[i].next)
    {
        dfs(E[i].v);
        for(int j = 0 ; j <= res[root] ; j ++) tmp[j] = dp[root][j];
        for(int j = 0 ; j <= res[root] ; j ++)
        {
            for(int k = 1 ; k <= res[E[i].v] ; k ++)
            {
                dp[root][j+k] = max(dp[root][j+k] , tmp[j] + dp[E[i].v][k] - E[i].w);
            }
        }
        res[root] += res[E[i].v];
    }
//    cout << root << ":" << endl;
//    for(int i = 0 ; i <= res[root] ; i ++) cout << dp[root][i] << " ";
//    cout << endl;
}

int main()
{
    while(scanf("%d %d" , &n , &m)!= EOF)
    {
        id = 0;
        mem(head , -1);
        mem(dp , 0);
        for(int i = 1 ;i <= n ; i ++)
        {
            for(int j = 1 ; j <= m ;j ++)
                dp[i][j] = -MOD;
        }
        int v , w , up = n - m , num;
        for(int i = 1 ; i <= up ; i ++)
        {
            scanf("%d" , &num);
            for(int j = 0 ; j < num ; j  ++)
            {
                scanf("%d %d" , &v , &w);
                add(i , v , w);
            }
            res[i] = 0;
        }
        for(int i = up+1 ; i <= n ; i ++) {scanf("%d" , &val[i]) ; dp[i][1] = val[i] ; res[i] = 1;}
        dfs(1);
        for(int i = m ; i >= 0 ; i --)
        {
           // cout << dp[1][i] << endl;
            if(dp[1][i] >= 0)
            {
                printf("%d\n" , i);
                break;
            }
        }
    }
    return 0;
}