1. 程式人生 > 實用技巧 >AGC019E - Shuffle and Swap 題解

AGC019E - Shuffle and Swap 題解

題目連結:E - Shuffle and Swap

題目大意:洛谷


題解:這一道題有兩個做法。

Solution 1

考慮有 \(n\) 個位置上 \(A,B\) 均為 1 ,設為位置 A ,有 \(2m\) 個位置上 \(A\)1 , 或 \(B\)1 ,設為位置 B(不包含位置 A),那麼我們的目標就是讓最後位置 B 的個數減少為 0,那麼我們每一次交換位置 A 中的兩個數並不會是位置 B 的個數發生變化,所以不管它,如果我們交換位置 B 中的兩個數則會使位置 B 的個數減 1,這種這種情況的方案數是 \(m^2\) ,如果我們交換一個位置 A 上的數和一個位置 B 上的數,則會使位置 A 的個數減 1 ,位置 B 的個數不變。

因此,我們的轉移方程就是 \(f_{n,m}= f_{n,m-1}\times m^2+f_{n-1,m}\times n\times m\)

最後我們考慮位置 A 還剩餘的方案數,\(ans =\sum_{i=0}^{n} f_{n-i,m} \times (i!)^2 \times \binom{n}{i}\times \binom{n+m}{i}\)

時間複雜度和空間複雜度均為 \(O(n^2)\)

Solution 2

考慮一個重要的轉化:題目中的 \((k!)^2\) 中方案我們可以分成兩個步驟,第一個步驟是給每一個 \(A\) 中的 1 匹配一個 \(B\) 中的 1 ,第二個步驟是給這些情況重新排列。

現在我們更改一下 \(n,m\) 的定義,令 \(n\) 表示 \(A,B\) 上有一個位置為 1 的方案數, \(m\) 的意義同 Solution 1 中的不變。

考慮步驟一,那麼如果我們對於 \(A\) 中的 1 的位置,向和它匹配的 \(B\) 中的位置連一條邊,那麼我們會發現,整張圖被我們分成了若干條鏈和若干個環,鏈的個數恰好是 \(m\) 個,並且,鏈的選擇是有順序要求的,即必須從鏈首選到鏈尾,而環的選擇則沒有要求了,接下來我們考慮使用生成函式來表示這個東西,因為兩條鏈或者兩個環在組合的時候是需要乘上組合數的,所以考慮用指數型生成函式來解決。

鏈的指數型生成函式:(假設起點已經確定,所以在結束之後還需要乘上\(m!\)

到答案中, \([x^i]F(x)\) 表示在鏈的端點之間有 \(i\) 個點的方案數。)

\[F(x)=\sum_{i=0}^{\infty} \frac{i!\times x^i}{i!\times (i+1)!} = \frac{e^x-1}{x} \]

環的指數型生成函式:(\([x^i]G(x)\)表示 \(i\) 個點的環的方案數。)

\[G(x)=\sum_{i=1}^{\infty} \frac{(i-1)!\times i!\times x^i}{i!\times i!} = -\ln(1-x) \]

所以我們需要將環和鏈組合起來,因為鏈的組合是有序的,而環的組合是無序的,所以最後的結果就是:

\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} ([x^i]F^m(x)) ([x^i]\exp(G(x))) \]

然後把函式帶進去展開得到:

\[n!\times m!\times (n-m)! \times \sum_{i=0}^{n-m} [x^i](\frac{e^x-1}{x})^m \]

然後就可以直接計算答案了,時間複雜度 \(O(n\log n)\),空間複雜度 \(O(n)\)

Solution 1 的程式碼:

#include <cstdio>
int quick_power(int a,int b,int Mod){
	int ans=1;
	while(b){
		if(b&1){
			ans=1ll*ans*a%Mod;
		}
		b>>=1;
		a=1ll*a*a%Mod;
	}
	return ans;
}
const int Maxn=10000;
const int Mod=998244353;
int f[Maxn+5][Maxn+5];
int n,k;
char a[Maxn+5],b[Maxn+5];
int s_1,s_2;
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
	frac[0]=1;
	for(int i=1;i<=Maxn;i++){
		frac[i]=1ll*frac[i-1]*i%Mod;
	}
	inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
	for(int i=Maxn-1;i>=0;i--){
		inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
	}
}
int C(int n,int m){
	return 1ll*frac[n]*inv_f[m]%Mod*inv_f[n-m]%Mod;
}
int main(){
	init();
	scanf("%s",a+1);
	scanf("%s",b+1);
	while(a[++n]!='\0');
	n--;
	for(int i=1;i<=n;i++){
		if(a[i]=='1'){
			k++;
		}
		if(a[i]=='1'&&b[i]=='1'){
			s_1++;
		}
		else if(a[i]=='1'){
			s_2++;
		}
	}
	f[0][0]=1;
	for(int i=0;i<=s_1;i++){
		for(int j=1;j<=s_2;j++){
			if(i==0&&j==0){
				continue;
			}
			f[i][j]=(f[i][j]+1ll*f[i][j-1]*j%Mod*j)%Mod;
			if(i>0){
				f[i][j]=(f[i][j]+1ll*f[i-1][j]*i%Mod*j)%Mod;
			}
		}
	}
	int ans=0;
	for(int i=0;i<=s_1;i++){
		ans=(ans+1ll*f[s_1-i][s_2]*frac[i]%Mod*frac[i]%Mod*C(s_1,i)%Mod*C(k,i))%Mod;
	}
	printf("%d\n",ans);
	return 0;
}

Solution 2 的程式碼:

#include <cstdio>
#include <cstring>
#include <algorithm>
using namespace std;
int quick_power(int a,int b,int Mod){
	int ans=1;
	while(b){
		if(b&1){
			ans=1ll*ans*a%Mod;
		}
		b>>=1;
		a=1ll*a*a%Mod;
	}
	return ans;
}
const int Maxn=40000;
const int G=3;
const int Mod=998244353;
int n,m,len;
char a[Maxn+5],b[Maxn+5];
void NTT(int *a,int flag,int n){
	static int R[Maxn+5];
	int len=1,L=0;
	while(len<n){
		len<<=1;
		L++;
	}
	for(int i=0;i<len;i++){
		R[i]=(R[i>>1]>>1)|((i&1)<<(L-1));
	}
	for(int i=0;i<len;i++){
		if(i<R[i]){
			swap(a[i],a[R[i]]);
		}
	}
	for(int j=1;j<len;j<<=1){
		int T=quick_power(G,(Mod-1)/(j<<1),Mod);
		for(int k=0;k<len;k+=(j<<1)){
			for(int l=0,t=1;l<j;l++,t=1ll*t*T%Mod){
				int Nx=a[k+l],Ny=1ll*t*a[k+l+j]%Mod;
				a[k+l]=(Nx+Ny)%Mod;
				a[k+l+j]=(Nx-Ny+Mod)%Mod;
			}
		}
	}
	if(flag==-1){
		reverse(a+1,a+len);
		for(int i=0,t=quick_power(len,Mod-2,Mod);i<len;i++){
			a[i]=1ll*a[i]*t%Mod;
		}
	}
}
void find_inv(int *a,int *b,int len){
	static int c[Maxn+5],d[Maxn+5];
	if(len==1){
		b[0]=quick_power(a[0],Mod-2,Mod);
		return;
	}
	find_inv(a,b,len>>1);
	for(int i=0;i<len;i++){
		c[i]=a[i];
		d[i]=b[i];
	}
	for(int i=len;i<(len<<1);i++){
		c[i]=d[i]=0;
	}
	NTT(c,1,len<<1);
	NTT(d,1,len<<1);
	for(int i=0;i<(len<<1);i++){
		d[i]=1ll*d[i]*d[i]%Mod*c[i]%Mod;
	}
	NTT(d,-1,len<<1);
	for(int i=0;i<len;i++){
		b[i]=((b[i]<<1)%Mod-d[i]+Mod)%Mod;
	}
	for(int i=len;i<(len<<1);i++){
		b[i]=0;
	}
}
void find_dev(int *a,int len){
	for(int i=0;i<len;i++){
		a[i]=1ll*(i+1)*a[i+1]%Mod;
	}
	a[len-1]=0;
}
void find_dev_inv(int *a,int len){
	for(int i=len-1;i>0;i--){
		a[i]=1ll*quick_power(i,Mod-2,Mod)*a[i-1]%Mod;
	}
	a[0]=0;
}
void find_ln(int *a,int *b,int n){
	static int c[Maxn+5];
	for(int i=0;i<n;i++){
		c[i]=a[i];
	}
	find_dev(c,n);
	int len=1;
	while(len<n){
		len<<=1;
	}
	find_inv(a,b,len);
	for(int i=n;i<len;i++){
		b[i]=0;
	}
	for(int i=len;i<(len<<1);i++){
		b[i]=c[i]=0;
	}
	NTT(b,1,len<<1);
	NTT(c,1,len<<1);
	for(int i=0;i<(len<<1);i++){
		b[i]=1ll*b[i]*c[i]%Mod;
	}
	NTT(b,-1,len<<1);
	find_dev_inv(b,len<<1);
	for(int i=n;i<(len<<1);i++){
		b[i]=0;
	}
}
void find_exp(int *a,int *b,int len){
	static int c[Maxn+5];
	if(len==1){
		b[0]=1;
		return;
	}
	find_exp(a,b,len>>1);
	find_ln(b,c,len);
	c[0]=(a[0]+1-c[0]+Mod)%Mod;
	for(int i=1;i<len;i++){
		c[i]=(a[i]-c[i]+Mod)%Mod;
	}
	for(int i=len;i<(len<<1);i++){
		b[i]=c[i]=0;
	}
	NTT(b,1,len<<1);
	NTT(c,1,len<<1);
	for(int i=0;i<(len<<1);i++){
		b[i]=1ll*b[i]*c[i]%Mod;
	}
	NTT(b,-1,len<<1);
	for(int i=len;i<(len<<1);i++){
		b[i]=c[i]=0;
	}
	for(int i=0;i<len;i++){
		printf("%d ",b[i]);
	}
	puts("");
}
int f[Maxn+5],g[Maxn+5];
int frac[Maxn+5],inv_f[Maxn+5];
void init(){
	frac[0]=1;
	for(int i=1;i<=Maxn;i++){
		frac[i]=1ll*frac[i-1]*i%Mod;
	}
	inv_f[Maxn]=quick_power(frac[Maxn],Mod-2,Mod);
	for(int i=Maxn-1;i>=0;i--){
		inv_f[i]=1ll*inv_f[i+1]*(i+1)%Mod;
	}
}
int main(){
	init();
	scanf("%s",a+1);
	scanf("%s",b+1);
	while(a[++len]!='\0');
	for(int i=1;i<=len;i++){
		if(a[i]=='1'){
			n++;
			if(b[i]=='0'){
				m++;
			}
		}
	}
	for(int i=0;i<=n-m;i++){
		f[i]=inv_f[i+1];
	}
	int len=1;
	while(len<=n-m){
		len<<=1;
	}
	find_ln(f,g,len);
	memset(f,0,sizeof f);
	for(int i=0;i<=n-m;i++){
		f[i]=1ll*g[i]*m%Mod;
	}
	memset(g,0,sizeof g);
	find_exp(f,g,len);
	int ans=0;
	for(int i=0;i<=n-m;i++){
		f[i]=g[i];
		ans=(ans+f[i])%Mod;
	}
	ans=1ll*ans*frac[m]%Mod*frac[n-m]%Mod*frac[n]%Mod;
	printf("%d\n",ans);
	return 0;
}