1. 程式人生 > 其它 >[AGC019E] Shuffle and Swap

[AGC019E] Shuffle and Swap

tag:組合計數,生成函式,多項式快速冪


蹲坑想出來的大體思路(霧

為了方便表述,下面這種形式

\[\begin{matrix}A_{x_1}&A_{x_2}&\cdots&A_{x_n}\\B_{y_1}&B_{y_2}&\cdots&B_{y_n}\end{matrix} \]

表示一串交換操作 \((x_1,x_2)(x_2,x_3)\cdots(x_{n-1},x_n)\)


對於一次交換操作,顯然可以把這兩列單獨拿出來(根據題目,一定滿足其中一列的 \(B=1\),另外一列的 \(A=1\)

\[\begin{matrix}x&1\\1&y\end{matrix} \]

考慮列舉另外兩個位置的所有可能性,會發現:

  • \(x=y=0\) 時,交換後 \(A\)\(B\) 相同的位置+1
  • 其餘情況,相同位置數不變

定義一串操作為有效操作,當且僅當依次執行完以後相同位置數+1。

顯然

\[\begin{matrix}0&1\\1&0\end{matrix} \]

是一串有效操作

進一步有形如:

\[\begin{matrix}0&1&\cdots&1&1\\1&1&\cdots&1&0\end{matrix} \]

都是有效操作(中間全是 \(1\)


設有 \(p\) 個位置滿足 \(A_i=1,B_i=0\)

那麼初始相同位置數為 \(k-p\)

,然後發現一個合法的操作序列一定可以分為若干個子序列(不一定連續是因為有效操作串之間互不影響),滿足剛好有 \(p\)有效操作,和一堆
\(\begin{matrix}1\\1\end{matrix}\) 操作。


對於一個包含 \(x\) 箇中間元素的有效操作串,貢獻為 \(x!\) 因為中間元素可以是任意順序。


然後就可以列舉被當作中間元素\(\begin{matrix}1\\1\end{matrix}\) 操作的個數(顯然這部分貢獻只與個數有關),然後再列舉每一種分配方案,求出貢獻和。再乘上一些組合數就可以求出答案。

\[ans=p!\cdot\sum_{x=0}^{k-p}\sum_{\sum d_i=x}(\binom{k-p}{d_1\quad d_2\quad\cdots\quad d_k}\binom{k}{d_1+1\quad d_2+1\quad\cdots\quad d_k+1}\Pi(d_i)!) \]\[ans=p!\cdot(k-p)!\cdot k!\sum_{x=0}^{k-p}\sum_{\sum d_i=x}\frac1{(d_i+1)!} \]

後面的和式可以寫成生成函式的形式

\[\sum_{i=0}^{k-p}[x^i]f^p(x) \]\[f(x)=\sum \frac1{(d_i+1)!}x^i \]

\(f^p(x)\) 可以用多項式快速冪求出。

複雜度 \(O(nlogn)\)


快速冪部分是複製的

#include<bits/stdc++.h>
using namespace std;

template<typename T>
inline void Read(T &n){
	char ch; bool flag=0;
	while(!isdigit(ch=getchar()))if(ch=='-')flag=1;
	for(n=ch^48;isdigit(ch=getchar());n=((n<<1)+(n<<3)+(ch^48))%998244353);
	if(flag)n=-n;
}

typedef long long ll;
const ll MOD = 998244353;
const ll G = 3;
const ll invG = (MOD+1)/G;
const int MAXN = 10005;

inline ll ksm(ll base, ll k=MOD-2){
    ll res=1;
    while(k){
        if(k&1)
            res=res*base%MOD;
        base=base*base%MOD;
        k>>=1ll;
    }
    return res;
}

ll invN;
int tr[MAXN<<2];
inline int Make(int len){
    int L = 1<<((int)(log(len)/log(2))+1); invN = ksm(L);
    for(register int i=0; i<L; i++) tr[i]=(tr[i>>1]>>1)|((i&1)?(L>>1):0);
    return L;
}

ll jc[MAXN], invjc[MAXN], inv[MAXN];

inline void ntt(ll *f, int len, int flag){
    for(register int i=0; i<len; i++) if(i<tr[i]) swap(f[i],f[tr[i]]);
    for(register int k=2; k<=len; k<<=1){
        int L=k>>1; ll base=ksm(flag==1?G:invG,(MOD-1)/k);
        for(register int l=0; l<len; l+=k){
            ll now=1;
            for(register int p=l; p<l+L; p++){
                ll tmp=f[p+L]*now%MOD;
                f[p+L]=f[p]-tmp; if(f[p+L]<0) f[p+L]+=MOD;
                f[p]=f[p]+tmp; if(f[p]>=MOD) f[p]-=MOD;
                now=now*base%MOD;
            }
        }
    }
    if(flag==-1) for(register int i=0; i<len; i++) f[i]=f[i]*invN%MOD;
}

inline void poly_inv(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2];
    if(len==1) return void(res[0]=ksm(f[0]));
    poly_inv(f,res,(len+1)/2);
    copy(f,f+len,tmp);
    int L=Make(len+len);
    ntt(tmp,L,1); ntt(res,L,1);
    for(register int i=0; i<L; i++) res[i]=(res[i]+res[i]-tmp[i]*res[i]%MOD*res[i]%MOD+MOD)%MOD;
    ntt(res,L,-1);
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
}

inline void poly_ln(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2];
    for(register int i=1; i<len; i++) res[i-1]=f[i]*i%MOD;
    poly_inv(f,tmp,len);
    int L=Make(len+len);
    ntt(res,L,1); ntt(tmp,L,1);
    for(register int i=0; i<L; i++) res[i]=res[i]*tmp[i]%MOD;
    ntt(res,L,-1);
    for(register int i=len-1; i>=1; i--) res[i]=res[i-1]*inv[i]%MOD; res[0]=0;
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
}

inline void poly_exp(ll *f, ll *res, int len){
    static ll tmp[MAXN<<2], Log[MAXN<<2];
    if(len==1) return void(res[0]=1);
    poly_exp(f,res,(len+1)/2);
    copy(f,f+len,tmp);
    poly_ln(res,Log,len);
    int L=Make(len*1.5);
    ntt(res,L,1); ntt(tmp,L,1); ntt(Log,L,1);
    for(register int i=0; i<L; i++) res[i]=(1ll-Log[i]+tmp[i]+MOD)%MOD*res[i]%MOD;
    ntt(res,L,-1);
    fill(res+len,res+L,0);
    fill(tmp,tmp+L,0);
    fill(Log,Log+L,0);
}

inline void poly_qpow(ll *f, ll *res, ll k, int len){
    static ll tmp[MAXN<<2];
    poly_ln(f,tmp,len);
    for(register int i=0; i<len; i++) tmp[i]=tmp[i]*k%MOD;
    poly_exp(tmp,res,len);
}

ll f[MAXN<<2], g[MAXN<<2];

char a[MAXN], b[MAXN];

int main(){
	// freopen("1.in","r",stdin);
	scanf("%s%s",a+1,b+1);
	int A=0, k=0, n=strlen(a+1);
	for(register int i=1; i<=n; i++)
		if(a[i]=='1') k++, A += (b[i]=='0');

    jc[0]=1; for(register int i=1; i<=n; i++) jc[i]=jc[i-1]*(ll)i%MOD, inv[i]=ksm(i);
    invjc[n]=ksm(jc[n]); invjc[0]=1;
    for(register int i=n-1; i>=1; i--) invjc[i]=invjc[i+1]*(ll)(i+1)%MOD;
	for(register int i=0; i<=k-A; i++) f[i] = 1ll*invjc[i+1]%MOD;
    poly_qpow(f,g,A,k-A+1);

	int ans=0;
	for(register int i=0; i<=k-A; i++)
		ans = (ans+1ll*jc[A]*jc[k-A]%MOD*g[i])%MOD;
	cout<<1ll*ans*jc[k]%MOD<<endl;
    return 0;
}