[總結] 容斥原理
阿新 • • 發佈:2021-08-12
[總結] 容斥原理
本篇文章用於介紹簡單的容斥原理。
定義
- 加上多減的,減去多加的。(可以畫 Venn 圖來理解)
適用條件
一般套路如下:
-
總方案數容易求得。
-
把 每個集合看成打破 \(|S|\) 條限制的非法方案集合,最後求得的 \(|\bigcup_{i=1}^nS_i|\) 的意義就是 所有非法方案總數。
-
一般可以通過 等價轉化 將容斥過程中的方案數計算出來。
-
資料範圍較小(允許狀壓)。
我們把每一條限制看成集合元素,列舉每一個集合,因為最後是減去不合法方案,所以容斥係數全部由 \((-1)^{|S|-1}\) 變為 \((-1)^{|S|}\)。
其中 \(|S_1 \cup S_2|\)
由於容斥過程中是求 \(∩\),所以考慮第二種。
換句話說就是 列舉所有條件集合,這就相當於 枚舉了所有相交後的集合。
例題
[HAOI2008]硬幣購物
給你四種面值 \(c_i\) 的硬幣,每種硬幣有無限個,求得在限制每種硬幣使用個數 \(d_i\) 的前提下湊成 \(s\) 價值的方案數。
先考慮沒有限制,跑一遍完全揹包即可,可以得到總方案數。
考慮打破限制的情況:
當前集合為 \(S\),有第 \(i\) 條限制,那麼我就強制選 \((d_i+1)\times c_i\)
這體現了上面說的:一般可以通過等價轉化將容斥過程中的方案數計算出來。
比較板的容斥。
#include <iostream> #include <cstdio> #include <cstring> #include <algorithm> using namespace std; template <typename T> inline T read(){ T x=0;char ch=getchar();bool fl=false; while(!isdigit(ch)){if(ch=='-')fl=true;ch=getchar();} while(isdigit(ch)){ x=(x<<3)+(x<<1)+(ch^48);ch=getchar(); } return fl?-x:x; } #define LL long long const int maxn = 1e5 + 10; LL f[maxn]; LL val[10],d[10],q,s; void init(){ f[0]=1; for(int i=1;i<=4;i++) for(int j=val[i];j<=100000;j++)f[j]+=f[j-val[i]]; } #define read() read<LL>() int main(){ for(int i=1;i<=4;i++)val[i]=read();q=read(); init(); while(q--){ for(int i=1;i<=4;i++)d[i]=read(); s=read(); LL ans=0; for(int S=0;S<(1<<4);S++){ LL sz=0,fl,sum=0; for(int i=0;i<4;i++)if((S>>i)&1){ sz++;sum+=(d[i+1]+1)*val[i+1]; } if(s<sum)continue; fl=((sz&1)?(-1):1); ans+=fl*f[s-sum]; } printf("%lld\n",ans); } return 0; }
CF451E Devu and Flowers
多重集的組合數。(早期程式碼)
#include <iostream>
#include <cstdio>
#include <set>
#include <cstring>
#include <algorithm>
using namespace std;
#define LL long long
const int P = 1000000007;
const int maxn = 22;
LL m,a[maxn],ans=0;
LL inv[maxn],n;
LL power(LL a,LL b){
LL res=1;
while(b){
if(b&1)res=res*1LL*a%P;
a=1LL*a*a%P;
b>>=1;
}
return res;
}
LL C(LL n,LL m){
if(n<0 || m<0 || n<m)return 0;
n%=P;
if(n==0 || m==0)return 1LL;
LL res=1;
for(int i=0;i<m;i++)res=res*(n-i)%P;
for(int i=1;i<=m;i++)res=res*inv[i]%P;
return res;
}
void init(){
for(int i=1;i<=20;i++)inv[i]=power(i,P-2);
return ;
}
int main(){
init();
scanf("%lld%lld",&n,&m);
for(int i=1;i<=n;i++)scanf("%lld",a+i);
for(int s=0;s<(1<<n);s++){
if(s==0){
ans=(ans+C(n+m-1,n-1))%P;
//cerr<<ans<<endl;
}
else{
LL tmp=n+m;
int p=0;
for(int i=1;i<=n;i++){
if(s&(1<<i-1)){
p++;
tmp-=a[i];
}
}
tmp-=p+1;
if(p&1){
ans=(ans-C(tmp,n-1))%P;
}
else{
ans=(ans+C(tmp,n-1))%P;
}
}
}
printf("%lld",(ans+P)%P);
return 0;
}