1. 程式人生 > 其它 >[總結] 容斥原理

[總結] 容斥原理

[總結] 容斥原理

本篇文章用於介紹簡單的容斥原理。

定義

  • 加上多減的,減去多加的。(可以畫 Venn 圖來理解)

適用條件

一般套路如下:

  • 總方案數容易求得。

  • 每個集合看成打破 \(|S|\) 條限制的非法方案集合,最後求得的 \(|\bigcup_{i=1}^nS_i|\) 的意義就是 所有非法方案總數

  • 一般可以通過 等價轉化 將容斥過程中的方案數計算出來。

  • 資料範圍較小(允許狀壓)。

我們把每一條限制看成集合元素,列舉每一個集合,因為最後是減去不合法方案,所以容斥係數全部由 \((-1)^{|S|-1}\) 變為 \((-1)^{|S|}\)

其中 \(|S_1 \cup S_2|\)

就等價於打破 \(S_1\)\(S_2\) 至少一個的方案數,\(|S_1 \cap S_2|\) 表示打破 \(S_1\)\(S_2\) 所有限制的方案數。

由於容斥過程中是求 \(∩\),所以考慮第二種。

換句話說就是 列舉所有條件集合,這就相當於 枚舉了所有相交後的集合

例題

[HAOI2008]硬幣購物

給你四種面值 \(c_i\) 的硬幣,每種硬幣有無限個,求得在限制每種硬幣使用個數 \(d_i\) 的前提下湊成 \(s\) 價值的方案數。

先考慮沒有限制,跑一遍完全揹包即可,可以得到總方案數。

考慮打破限制的情況:

當前集合為 \(S\),有第 \(i\) 條限制,那麼我就強制選 \((d_i+1)\times c_i\)

的第 \(i\) 種硬幣即可。

這體現了上面說的:一般可以通過等價轉化將容斥過程中的方案數計算出來。

比較板的容斥。

#include <iostream>
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
template <typename T>
inline T read(){
	T x=0;char ch=getchar();bool fl=false;
	while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();}
	while(isdigit(ch)){
		x=(x<<3)+(x<<1)+(ch^48);ch=getchar();
	}
	return fl?-x:x;
}
#define LL long long
const int maxn = 1e5 + 10;
LL f[maxn];
LL val[10],d[10],q,s;
void init(){
	f[0]=1;
	for(int i=1;i<=4;i++)
		for(int j=val[i];j<=100000;j++)f[j]+=f[j-val[i]];
}
#define read() read<LL>()
int main(){
	for(int i=1;i<=4;i++)val[i]=read();q=read();
	init();
	while(q--){
		for(int i=1;i<=4;i++)d[i]=read();
		s=read();
		LL ans=0;
		for(int S=0;S<(1<<4);S++){
			LL sz=0,fl,sum=0;
			for(int i=0;i<4;i++)if((S>>i)&1){
				sz++;sum+=(d[i+1]+1)*val[i+1];
			}
			if(s<sum)continue;
			fl=((sz&1)?(-1):1);
			ans+=fl*f[s-sum];
		}
		printf("%lld\n",ans);
	}
	return 0;
}

CF451E Devu and Flowers

多重集的組合數。(早期程式碼)

#include <iostream>
#include <cstdio>
#include <set>
#include <cstring>
#include <algorithm>
using namespace std;
#define LL long long
const int P = 1000000007;
const int maxn = 22;
LL m,a[maxn],ans=0;
LL inv[maxn],n;
LL power(LL a,LL b){
	LL res=1;
	while(b){
		if(b&1)res=res*1LL*a%P;
		a=1LL*a*a%P;
		b>>=1;
	}
	return res;
}
LL C(LL n,LL m){
	if(n<0 || m<0 || n<m)return 0;
	n%=P;
	if(n==0 || m==0)return 1LL;
	LL res=1;
	for(int i=0;i<m;i++)res=res*(n-i)%P;
	for(int i=1;i<=m;i++)res=res*inv[i]%P;
	return res;
}
void init(){
	for(int i=1;i<=20;i++)inv[i]=power(i,P-2);
	return ;
}
int main(){
	init();
	scanf("%lld%lld",&n,&m);
	for(int i=1;i<=n;i++)scanf("%lld",a+i);
	for(int s=0;s<(1<<n);s++){
		if(s==0){
			ans=(ans+C(n+m-1,n-1))%P; 
			//cerr<<ans<<endl;
		}
		else{
			LL tmp=n+m;
			int p=0;
			for(int i=1;i<=n;i++){
				if(s&(1<<i-1)){
					p++;
					tmp-=a[i];
				}
			}
			tmp-=p+1; 
			if(p&1){
				ans=(ans-C(tmp,n-1))%P;
			}
			else{
				ans=(ans+C(tmp,n-1))%P;
			}
		}
	}
	printf("%lld",(ans+P)%P);
	return 0;
}