1. 程式人生 > >藍魔法師(樹形dp)

藍魔法師(樹形dp)

時間限制:C/C++ 1秒,其他語言2秒 空間限制:C/C++ 262144K,其他語言524288K 64bit IO Format: %lld

題目描述

“你,你認錯人了。我真的,真的不是食人魔。”--藍魔法師

給出一棵樹,求有多少種刪邊方案,使得刪後的圖每個連通塊大小小於等於k,兩種方案不同當且僅當存在一條邊在一個方案中被刪除,而在另一個方案中未被刪除,答案對998244353取模

輸入描述:

第一行兩個整數n,k, 表示點數和限制
2 <= n <= 2000, 1 <= k <= 2000
接下來n-1行,每行包括兩個整數u,v,表示u,v兩點之間有一條無向邊
保證初始圖聯通且合法

輸出描述:

共一行,一個整數表示方案數對998244353取模的結果

示例1

輸入

複製

5 2
1 2
1 3
2 4
2 5

輸出

複製

7

題解:

用dp[i][j]表示以i為根結點的子樹中,當前包含節點i的連通塊的大小為j的方案個數。

其中dp[x][0]特殊,用來表示∑dp[x][i]  (1≤i≤k)

然後每一個節點u來說,他的一部分v兒子已經搜完了,此時的dp[u][p]就是加上了搜到過的v的,聯通塊大小為p的個數。然後繼續搜的時候就繼續用dp【u】【i】*dp【v】【k-i】加到原來的dp【u】【p】上。這樣是不重複的。

程式碼:

#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn=10000;
const int mod=998244353;
LL dp[maxn][maxn];
vector<int>G[maxn];
int temp[maxn];
int siz[maxn];
int n,k;
void dfs(int u,int fa){
    dp[u][1]=1;
    siz[u]=1;
    int v;
    for(int i=0;i<G[u].size();i++){
        v=G[u][i];
        if(v==fa)
            continue;
        dfs(v,u);
        memset(temp,0,sizeof(temp));
        for(int ii=1;ii<=siz[u];ii++){
            for(int j=0;j<=min(k-ii,siz[v]);j++){
                temp[ii+j]=(temp[ii+j]%mod+dp[u][ii]*dp[v][j]%mod)%mod;
            }
        }
        for(int ii=1;ii<=k;ii++){
            dp[u][ii]=temp[ii];
        }
        siz[u]+=siz[v];
    }
    for(int i=1;i<=k;i++)///到最末端節點
    {
    ///    cout<<"*******dp["<<u<<"]["<<0<<"]:"<<dp[u][0]<<endl;
        dp[u][0] = (dp[u][0]+dp[u][i])%mod;
    }
}
void dfs1(int u, int p)
{
   /// cout<<"u:"<<u<<"     p:"<<p<<endl;
	int i, j, v, q;
	dp[u][1] = siz[u] = 1;
	///cout<<"dp["<<u<<"]["<<0<<"]:"<<dp[u][0]<<endl;
	for(q=0;q<G[u].size();q++)
	{
		v = G[u][q];
		if(v==p)
			continue;
		dfs(v, u);
		memset(temp, 0, sizeof(temp));
		for(i=1;i<=siz[u];i++)
		{
			for(j=0;j<=min(siz[v], k-i);j++)
            {
               /// cout<<"i:"<<i<<"   j:"<<j<<"  dp["<<u<<"]["<<i<<"]:"<<dp[u][i]<<"  dp["<<v<<"]["<<j<<"]:"<<dp[v][j]<<endl;
                temp[i+j] = (temp[i+j]+dp[u][i]*dp[v][j])%mod;
            }
		}
		for(int i=1;i<=k;i++){
            dp[u][i]=temp[i];
		}
		///memcpy(dp[u], temp, sizeof(temp));
		siz[u] += siz[v];
	}
	for(i=1;i<=k;i++)///到最末端節點
    {
    ///    cout<<"*******dp["<<u<<"]["<<0<<"]:"<<dp[u][0]<<endl;
        dp[u][0] = (dp[u][0]+dp[u][i])%mod;
    }
}
int main()
{
    int a,b;
    while(scanf("%d%d",&n,&k)!=EOF){
        for(int i=1;i<=n-1;i++){
            scanf("%d%d",&a,&b);
            G[a].push_back(b);
            G[b].push_back(a);
        }
        dfs(1,0);
        cout<<dp[1][0]<<endl;
    }
    return 0;
}