[AGC019E] Shuffle and Swap
阿新 • • 發佈:2021-06-26
tag:組合計數,生成函式,多項式快速冪
蹲坑想出來的大體思路(霧
為了方便表述,下面這種形式
\[\begin{matrix}A_{x_1}&A_{x_2}&\cdots&A_{x_n}\\B_{y_1}&B_{y_2}&\cdots&B_{y_n}\end{matrix} \]表示一串交換操作 \((x_1,x_2)(x_2,x_3)\cdots(x_{n-1},x_n)\)
對於一次交換操作,顯然可以把這兩列單獨拿出來(根據題目,一定滿足其中一列的 \(B=1\),另外一列的 \(A=1\))
\[\begin{matrix}x&1\\1&y\end{matrix} \]考慮列舉另外兩個位置的所有可能性,會發現:
- \(x=y=0\) 時,交換後 \(A\) 和 \(B\) 相同的位置+1
- 其餘情況,相同位置數不變
定義一串操作為有效操作,當且僅當依次執行完以後相同位置數+1。
顯然
\[\begin{matrix}0&1\\1&0\end{matrix} \]是一串有效操作
進一步有形如:
\[\begin{matrix}0&1&\cdots&1&1\\1&1&\cdots&1&0\end{matrix} \]都是有效操作(中間全是 \(1\))
設有 \(p\) 個位置滿足 \(A_i=1,B_i=0\)
那麼初始相同位置數為 \(k-p\)
\(\begin{matrix}1\\1\end{matrix}\) 操作。
對於一個包含 \(x\) 箇中間元素的有效操作串,貢獻為 \(x!\) 因為中間元素可以是任意順序。
然後就可以列舉被當作中間元素的 \(\begin{matrix}1\\1\end{matrix}\) 操作的個數(顯然這部分貢獻只與個數有關),然後再列舉每一種分配方案,求出貢獻和。再乘上一些組合數就可以求出答案。
\[ans=p!\cdot\sum_{x=0}^{k-p}\sum_{\sum d_i=x}(\binom{k-p}{d_1\quad d_2\quad\cdots\quad d_k}\binom{k}{d_1+1\quad d_2+1\quad\cdots\quad d_k+1}\Pi(d_i)!) \]\[ans=p!\cdot(k-p)!\cdot k!\sum_{x=0}^{k-p}\sum_{\sum d_i=x}\frac1{(d_i+1)!} \]後面的和式可以寫成生成函式的形式
\(f^p(x)\) 可以用多項式快速冪求出。
複雜度 \(O(nlogn)\)
快速冪部分是複製的
#include<bits/stdc++.h>
using namespace std;
template<typename T>
inline void Read(T &n){
char ch; bool flag=0;
while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48))%998244353);
if(flag)n=-n;
}
typedef long long ll;
const ll MOD = 998244353;
const ll G = 3;
const ll invG = (MOD+1)/G;
const int MAXN = 10005;
inline ll ksm(ll base, ll k=MOD-2){
ll res=1;
while(k){
if(k&1)
res=res*base%MOD;
base=base*base%MOD;
k>>=1ll;
}
return res;
}
ll invN;
int tr[MAXN<<2];
inline int Make(int len){
int L = 1<<((int)(log(len)/log(2))+1); invN = ksm(L);
for(register int i=0; i<L; i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(L>>1):0);
return L;
}
ll jc[MAXN], invjc[MAXN], inv[MAXN];
inline void ntt(ll *f, int len, int flag){
for(register int i=0; i<len; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
for(register int k=2; k<=len; k<<=1){
int L=k>>1; ll base=ksm(flag==1?G:invG,(MOD-1)/k);
for(register int l=0; l<len; l+=k){
ll now=1;
for(register int p=l; p<l+L; p++){
ll tmp=f[p+L]*now%MOD;
f[p+L]=f[p]-tmp; if(f[p+L]<0) f[p+L]+=MOD;
f[p]=f[p]+tmp; if(f[p]>=MOD) f[p]-=MOD;
now=now*base%MOD;
}
}
}
if(flag==-1) for(register int i=0; i<len; i++) f[i]=f[i]*invN%MOD;
}
inline void poly_inv(ll *f, ll *res, int len){
static ll tmp[MAXN<<2];
if(len==1) return void(res[0]=ksm(f[0]));
poly_inv(f,res,(len+1)/2);
copy(f,f+len,tmp);
int L=Make(len+len);
ntt(tmp,L,1); ntt(res,L,1);
for(register int i=0; i<L; i++) res[i]=(res[i]+res[i]-tmp[i]*res[i]%MOD*res[i]%MOD+MOD)%MOD;
ntt(res,L,-1);
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
}
inline void poly_ln(ll *f, ll *res, int len){
static ll tmp[MAXN<<2];
for(register int i=1; i<len; i++) res[i-1]=f[i]*i%MOD;
poly_inv(f,tmp,len);
int L=Make(len+len);
ntt(res,L,1); ntt(tmp,L,1);
for(register int i=0; i<L; i++) res[i]=res[i]*tmp[i]%MOD;
ntt(res,L,-1);
for(register int i=len-1; i>=1; i--) res[i]=res[i-1]*inv[i]%MOD; res[0]=0;
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
}
inline void poly_exp(ll *f, ll *res, int len){
static ll tmp[MAXN<<2], Log[MAXN<<2];
if(len==1) return void(res[0]=1);
poly_exp(f,res,(len+1)/2);
copy(f,f+len,tmp);
poly_ln(res,Log,len);
int L=Make(len*1.5);
ntt(res,L,1); ntt(tmp,L,1); ntt(Log,L,1);
for(register int i=0; i<L; i++) res[i]=(1ll-Log[i]+tmp[i]+MOD)%MOD*res[i]%MOD;
ntt(res,L,-1);
fill(res+len,res+L,0);
fill(tmp,tmp+L,0);
fill(Log,Log+L,0);
}
inline void poly_qpow(ll *f, ll *res, ll k, int len){
static ll tmp[MAXN<<2];
poly_ln(f,tmp,len);
for(register int i=0; i<len; i++) tmp[i]=tmp[i]*k%MOD;
poly_exp(tmp,res,len);
}
ll f[MAXN<<2], g[MAXN<<2];
char a[MAXN], b[MAXN];
int main(){
// freopen("1.in","r",stdin);
scanf("%s%s",a+1,b+1);
int A=0, k=0, n=strlen(a+1);
for(register int i=1; i<=n; i++)
if(a[i]=='1') k++, A += (b[i]=='0');
jc[0]=1; for(register int i=1; i<=n; i++) jc[i]=jc[i-1]*(ll)i%MOD, inv[i]=ksm(i);
invjc[n]=ksm(jc[n]); invjc[0]=1;
for(register int i=n-1; i>=1; i--) invjc[i]=invjc[i+1]*(ll)(i+1)%MOD;
for(register int i=0; i<=k-A; i++) f[i] = 1ll*invjc[i+1]%MOD;
poly_qpow(f,g,A,k-A+1);
int ans=0;
for(register int i=0; i<=k-A; i++)
ans = (ans+1ll*jc[A]*jc[k-A]%MOD*g[i])%MOD;
cout<<1ll*ans*jc[k]%MOD<<endl;
return 0;
}