AGC019E - Shuffle and Swap 題解
題目連結:E - Shuffle and Swap
題目大意:洛谷
題解:這一道題有兩個做法。
Solution 1
考慮有 \(n\) 個位置上 \(A,B\) 均為 1
,設為位置 A ,有 \(2m\) 個位置上 \(A\) 為 1
, 或 \(B\) 為 1
,設為位置 B(不包含位置 A),那麼我們的目標就是讓最後位置 B 的個數減少為 0,那麼我們每一次交換位置 A 中的兩個數並不會是位置 B 的個數發生變化,所以不管它,如果我們交換位置 B 中的兩個數則會使位置 B 的個數減 1,這種這種情況的方案數是 \(m^2\) ,如果我們交換一個位置 A 上的數和一個位置 B 上的數,則會使位置 A 的個數減 1 ,位置 B 的個數不變。
因此,我們的轉移方程就是 \(f_{n,m}= f_{n,m-1}\times m^2+f_{n-1,m}\times n\times m\) 。
最後我們考慮位置 A 還剩餘的方案數,\(ans =\sum_{i=0}^{n} f_{n-i,m} \times (i!)^2 \times \binom{n}{i}\times \binom{n+m}{i}\)。
時間複雜度和空間複雜度均為 \(O(n^2)\)。
Solution 2
考慮一個重要的轉化:題目中的 \((k!)^2\) 中方案我們可以分成兩個步驟,第一個步驟是給每一個 \(A\) 中的 1
匹配一個 \(B\) 中的 1
,第二個步驟是給這些情況重新排列。
現在我們更改一下 \(n,m\) 的定義,令 \(n\) 表示 \(A,B\) 上有一個位置為 1
的方案數, \(m\) 的意義同 Solution 1 中的不變。
考慮步驟一,那麼如果我們對於 \(A\) 中的 1
的位置,向和它匹配的 \(B\) 中的位置連一條邊,那麼我們會發現,整張圖被我們分成了若干條鏈和若干個環,鏈的個數恰好是 \(m\) 個,並且,鏈的選擇是有順序要求的,即必須從鏈首選到鏈尾,而環的選擇則沒有要求了,接下來我們考慮使用生成函式來表示這個東西,因為兩條鏈或者兩個環在組合的時候是需要乘上組合數的,所以考慮用指數型生成函式來解決。
鏈的指數型生成函式:(假設起點已經確定,所以在結束之後還需要乘上\(m!\)
環的指數型生成函式:(\([x^i]G(x)\)表示 \(i\) 個點的環的方案數。)
\[G(x)=\sum_{i=1}^{\infty} \frac{(i-1)!\times i!\times x^i}{i!\times i!} = -\ln(1-x) \]所以我們需要將環和鏈組合起來,因為鏈的組合是有序的,而環的組合是無序的,所以最後的結果就是:
\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} ([x^i]F^m(x)) ([x^i]\exp(G(x))) \]然後把函式帶進去展開得到:
\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} [x^i](\frac{e^x-1}{x})^m \]然後就可以直接計算答案了,時間複雜度 \(O(n\log n)\),空間複雜度 \(O(n)\),
Solution 1 的程式碼:
#include <cstdio>
int quick_power(int a,int b,int Mod){
int ans=1;
while(b){
if(b&1){
ans=1ll*ans*a%Mod;
}
b>>=1;
a=1ll*a*a%Mod;
}
return ans;
}
const int Maxn=10000;
const int Mod=998244353;
int f[Maxn+5][Maxn+5];
int n,k;
char a[Maxn+5],b[Maxn+5];
int s_1,s_2;
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
frac[0]=1;
for(int i=1;i<=Maxn;i++){
frac[i]=1ll*frac[i-1]*i%Mod;
}
inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
for(int i=Maxn-1;i>=0;i--){
inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
}
}
int C(int n,int m){
return 1ll*frac[n]*inv_f[m]%Mod*inv_f[n-m]%Mod;
}
int main(){
init();
scanf("%s",a+1);
scanf("%s",b+1);
while(a[++n]!='\0');
n--;
for(int i=1;i<=n;i++){
if(a[i]=='1'){
k++;
}
if(a[i]=='1'&&b[i]=='1'){
s_1++;
}
else if(a[i]=='1'){
s_2++;
}
}
f[0][0]=1;
for(int i=0;i<=s_1;i++){
for(int j=1;j<=s_2;j++){
if(i==0&&j==0){
continue;
}
f[i][j]=(f[i][j]+1ll*f[i][j-1]*j%Mod*j)%Mod;
if(i>0){
f[i][j]=(f[i][j]+1ll*f[i-1][j]*i%Mod*j)%Mod;
}
}
}
int ans=0;
for(int i=0;i<=s_1;i++){
ans=(ans+1ll*f[s_1-i][s_2]*frac[i]%Mod*frac[i]%Mod*C(s_1,i)%Mod*C(k,i))%Mod;
}
printf("%d\n",ans);
return 0;
}
Solution 2 的程式碼:
#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int quick_power(int a,int b,int Mod){
int ans=1;
while(b){
if(b&1){
ans=1ll*ans*a%Mod;
}
b>>=1;
a=1ll*a*a%Mod;
}
return ans;
}
const int Maxn=40000;
const int G=3;
const int Mod=998244353;
int n,m,len;
char a[Maxn+5],b[Maxn+5];
void NTT(int *a,int flag,int n){
static int R[Maxn+5];
int len=1,L=0;
while(len<n){
len<<=1;
L++;
}
for(int i=0;i<len;i++){
R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
}
for(int i=0;i<len;i++){
if(i<R[i]){
swap(a[i],a[R[i]]);
}
}
for(int j=1;j<len;j<<=1){
int T=quick_power(G,(Mod-1)/(j<<1),Mod);
for(int k=0;k<len;k+=(j<<1)){
for(int l=0,t=1;l<j;l++,t=1ll*t*T%Mod){
int Nx=a[k+l],Ny=1ll*t*a[k+l+j]%Mod;
a[k+l]=(Nx+Ny)%Mod;
a[k+l+j]=(Nx-Ny+Mod)%Mod;
}
}
}
if(flag==-1){
reverse(a+1,a+len);
for(int i=0,t=quick_power(len,Mod-2,Mod);i<len;i++){
a[i]=1ll*a[i]*t%Mod;
}
}
}
void find_inv(int *a,int *b,int len){
static int c[Maxn+5],d[Maxn+5];
if(len==1){
b[0]=quick_power(a[0],Mod-2,Mod);
return;
}
find_inv(a,b,len>>1);
for(int i=0;i<len;i++){
c[i]=a[i];
d[i]=b[i];
}
for(int i=len;i<(len<<1);i++){
c[i]=d[i]=0;
}
NTT(c,1,len<<1);
NTT(d,1,len<<1);
for(int i=0;i<(len<<1);i++){
d[i]=1ll*d[i]*d[i]%Mod*c[i]%Mod;
}
NTT(d,-1,len<<1);
for(int i=0;i<len;i++){
b[i]=((b[i]<<1)%Mod-d[i]+Mod)%Mod;
}
for(int i=len;i<(len<<1);i++){
b[i]=0;
}
}
void find_dev(int *a,int len){
for(int i=0;i<len;i++){
a[i]=1ll*(i+1)*a[i+1]%Mod;
}
a[len-1]=0;
}
void find_dev_inv(int *a,int len){
for(int i=len-1;i>0;i--){
a[i]=1ll*quick_power(i,Mod-2,Mod)*a[i-1]%Mod;
}
a[0]=0;
}
void find_ln(int *a,int *b,int n){
static int c[Maxn+5];
for(int i=0;i<n;i++){
c[i]=a[i];
}
find_dev(c,n);
int len=1;
while(len<n){
len<<=1;
}
find_inv(a,b,len);
for(int i=n;i<len;i++){
b[i]=0;
}
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
NTT(b,1,len<<1);
NTT(c,1,len<<1);
for(int i=0;i<(len<<1);i++){
b[i]=1ll*b[i]*c[i]%Mod;
}
NTT(b,-1,len<<1);
find_dev_inv(b,len<<1);
for(int i=n;i<(len<<1);i++){
b[i]=0;
}
}
void find_exp(int *a,int *b,int len){
static int c[Maxn+5];
if(len==1){
b[0]=1;
return;
}
find_exp(a,b,len>>1);
find_ln(b,c,len);
c[0]=(a[0]+1-c[0]+Mod)%Mod;
for(int i=1;i<len;i++){
c[i]=(a[i]-c[i]+Mod)%Mod;
}
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
NTT(b,1,len<<1);
NTT(c,1,len<<1);
for(int i=0;i<(len<<1);i++){
b[i]=1ll*b[i]*c[i]%Mod;
}
NTT(b,-1,len<<1);
for(int i=len;i<(len<<1);i++){
b[i]=c[i]=0;
}
for(int i=0;i<len;i++){
printf("%d ",b[i]);
}
puts("");
}
int f[Maxn+5],g[Maxn+5];
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
frac[0]=1;
for(int i=1;i<=Maxn;i++){
frac[i]=1ll*frac[i-1]*i%Mod;
}
inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
for(int i=Maxn-1;i>=0;i--){
inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
}
}
int main(){
init();
scanf("%s",a+1);
scanf("%s",b+1);
while(a[++len]!='\0');
for(int i=1;i<=len;i++){
if(a[i]=='1'){
n++;
if(b[i]=='0'){
m++;
}
}
}
for(int i=0;i<=n-m;i++){
f[i]=inv_f[i+1];
}
int len=1;
while(len<=n-m){
len<<=1;
}
find_ln(f,g,len);
memset(f,0,sizeof f);
for(int i=0;i<=n-m;i++){
f[i]=1ll*g[i]*m%Mod;
}
memset(g,0,sizeof g);
find_exp(f,g,len);
int ans=0;
for(int i=0;i<=n-m;i++){
f[i]=g[i];
ans=(ans+f[i])%Mod;
}
ans=1ll*ans*frac[m]%Mod*frac[n-m]%Mod*frac[n]%Mod;
printf("%d\n",ans);
return 0;
}