1. 程式人生 > 實用技巧 >Codeforces 512D - Fox And Travelling(樹上揹包)

Codeforces 512D - Fox And Travelling(樹上揹包)

題面傳送門

題意:
給出一張無向圖,每次你可以選擇一個度數 \(\leq 1\) 的點並將其刪除。
問對於 \(k=0,1,2,\dots,n\),有多少個刪除 \(k\) 個點的序列,答案模 \(10^9+9\)
\(1\leq n\leq 100\)

很容易發現,環上的點是永遠不可能被刪除的。於是我們對原圖做一遍拓撲排序,求出哪些點可能被刪除。
顯然這些能被刪除的點一定會形成一棵森林。
而森林中的每個連通塊又可以分為兩類:一種是該連通塊在原圖中屬於的連通塊中有環相連。另一種是該連通塊在原圖中本身就是一棵樹。
對於第一種情況,一定有且僅有一個點 \(x\) 與環相鄰。此時 \(x\) 一定是連通塊中最後一個被刪除的。我們就以 \(x\)

為根跑一遍樹上揹包。具體來說,設 \(dp_{i,j}\) 表示 \(i\) 子樹中選了 \(j\) 個點的方案數。由於我們這是有根樹,必須選完 \(x\) 子樹內所有點才能選 \(x\)。轉移就正常求個卷積,再乘上個組合數的係數即可。由於任意兩點之間最多被合併一次,複雜度 \(n^2\)
對於第二種情況,我們並沒有給這棵樹定根。所以我們考慮以樹上每個點為根都跑一次樹上揹包。但這樣會重複計算。對於選了 \(i\) 個點的方案,在以未被選擇 \(s-i\) 個點為根情況下都被算了一次。故答案除以 \(s-i\)
然後最後再總體跑一遍揹包即可。時間複雜度 \(\mathcal O(n^3)\)

#include <bits/stdc++.h>
using namespace std;
#define fi first
#define se second
#define fz(i,a,b) for(int i=a;i<=b;i++)
#define fd(i,a,b) for(int i=a;i>=b;i--)
#define ffe(it,v) for(__typeof(v.begin()) it=v.begin();it!=v.end();it++)
#define fill0(a) memset(a,0,sizeof(a))
#define fill1(a) memset(a,-1,sizeof(a))
#define fillbig(a) memset(a,63,sizeof(a))
#define pb push_back
#define ppb pop_back
#define mp make_pair
template<typename T1,typename T2> void chkmin(T1 &x,T2 y){if(x>y) x=y;}
template<typename T1,typename T2> void chkmax(T1 &x,T2 y){if(x<y) x=y;}
typedef pair<int,int> pii;
typedef long long ll;
template<typename T> void read(T &x){
	x=0;char c=getchar();T neg=1;
	while(!isdigit(c)){if(c=='-') neg=-1;c=getchar();}
	while(isdigit(c)) x=x*10+c-'0',c=getchar();
	x*=neg;
}
const int MAXN=100;
const int MAXM=4950;
const int MOD=1e9+9;
int qpow(int x,int e){int ret=1;for(;e;x=1ll*x*x%MOD,e>>=1) if(e&1) ret=1ll*ret*x%MOD;return ret;}
int n,m,deg[MAXN+5],to[MAXM*2+5],ec=0,nxt[MAXM*2+5],hd[MAXN+5];
void adde(int u,int v){to[++ec]=v;nxt[ec]=hd[u];hd[u]=ec;}
bool vis[MAXN+5];
int from[MAXN+5],cmp=0,siz[MAXN+5],rt[MAXN+5],con[MAXN+5],c[MAXN+5][MAXN+5];
void dfs0(int x){
	if(from[x]) return;from[x]=cmp;siz[cmp]++;
//	printf("%d %d\n",x,cmp);
	for(int e=hd[x];e;e=nxt[e]) if(vis[to[e]]) dfs0(to[e]);
}
int dp[MAXN+5][MAXN+5],s[MAXN+5],tmp[MAXN+5],dpc[MAXN+5][MAXN+5];
void dfs(int x,int f){
//	printf("%d %d\n",x,f);
	s[x]=1;dp[x][0]=1;
	for(int e=hd[x];e;e=nxt[e]){
		int y=to[e];if(y==f) continue;dfs(y,x);
		memset(tmp,0,sizeof(tmp));
		for(int i=0;i<=s[x];i++) for(int j=0;j<=s[y];j++)
			tmp[i+j]=(tmp[i+j]+1ll*dp[x][i]*dp[y][j]%MOD*c[i+j][i]%MOD)%MOD;
		s[x]+=s[y];
		for(int i=0;i<=s[x];i++) dp[x][i]=tmp[i];
	} dp[x][s[x]]=dp[x][s[x]-1];
//	for(int i=0;i<=s[x];i++) printf("%d %d %d\n",x,i,dp[x][i]);
}
int iv[MAXN+5],ans[MAXN+5]; 
int main(){
	scanf("%d%d",&n,&m);
	for(int i=0;i<=MAXN;i++){
		c[i][0]=1;
		for(int j=1;j<=i;j++) c[i][j]=(c[i-1][j]+c[i-1][j-1])%MOD;
	}
	iv[0]=1;for(int i=1;i<=MAXN;i++) iv[i]=qpow(i,MOD-2);
	for(int i=1;i<=m;i++){
		int u,v;scanf("%d%d",&u,&v);
		adde(u,v);adde(v,u);deg[u]++;deg[v]++;
	}
	queue<int> q;
	for(int i=1;i<=n;i++) if(deg[i]<=1) q.push(i);
	while(!q.empty()){
		int x=q.front();q.pop();vis[x]=1;//printf("%d\n",x);
		for(int e=hd[x];e;e=nxt[e]){
			int y=to[e];deg[y]--;
			if(deg[y]<=1&&!vis[y]) q.push(y);
		}
	}
	for(int i=1;i<=n;i++) if(vis[i]&&!from[i]) cmp++,dfs0(i);
	for(int i=1;i<=n;i++) for(int e=hd[i];e;e=nxt[e]) if(vis[i]&&!vis[to[e]]) con[i]=to[e],rt[from[i]]=i;
//	for(int i=1;i<=cmp;i++) printf("%d\n",rt[i]);
	for(int i=1;i<=n;i++) if(vis[i]){
		if(rt[from[i]]&&i==rt[from[i]]){
			memset(dp,0,sizeof(dp));memset(s,0,sizeof(s));
			dfs(i,con[i]);
			for(int j=0;j<=siz[from[i]];j++) dpc[from[i]][j]=dp[i][j];
		} else if(!rt[from[i]]){
			memset(dp,0,sizeof(dp));memset(s,0,sizeof(s));
			dfs(i,0);//printf("%d\n",i);
//			for(int j=1;j<=siz[from[i]];j++) printf("%d ",dp[i][j]);printf("\n");
			for(int j=0;j<=siz[from[i]];j++) dpc[from[i]][j]=(dpc[from[i]][j]+dp[i][j])%MOD;
		}
	}
	for(int i=1;i<=cmp;i++) if(!rt[i]){
		for(int j=0;j<=siz[i];j++) dpc[i][j]=1ll*dpc[i][j]*iv[siz[i]-j]%MOD;
	} ans[0]=1;
//	for(int i=1;i<=cmp;i++) for(int j=0;j<=siz[i];j++) printf("%d %d %d\n",i,j,dpc[i][j]); 
	for(int i=1;i<=cmp;i++) for(int j=n;j;j--) for(int k=1;k<=min(siz[i],j);k++)
		ans[j]=(ans[j]+1ll*ans[j-k]*dpc[i][k]%MOD*c[j][k])%MOD;
	for(int i=0;i<=n;i++) printf("%d\n",ans[i]);
	return 0;
}