1. 程式人生 > 其它 >21.5.13 t2

21.5.13 t2

tag:揹包dp,數論


首先可以把給定的排列分成若干迴圈,將長度相同的分為一組,則可以分別處理每組然後乘起來。

對於一組數量為 \(cnt_a\) 長度為 \(a\) 的迴圈,再分成若干組,假設其中一組有 \(b\) 個,則必須滿足 \(\gcd(ab,k)=b\)。而這樣一組的貢獻為 \((b-1)!a^{b-1}\)

所以相當於是將 \(cnt_a\) 劃分成若干整數 \(b_i\),滿足 \(\gcd(ab_i,k)=b_i\),然後一種劃分方案的貢獻為 \(\binom{cnt_a}{b_1\ \cdots\ b_k}\Pi(b_i-1)!a^{b_i-1}=cnt_a!\Pi\frac 1{b_i}a^{b_i-1}\)

所以想到一種做法,設 \(f_i\) 表示前 \(i\) 個分成若干組,然後可以列舉最後一組分多少個進行dp。

考慮它的複雜度,為 \(O(n*\)合法的\(b\)的個數\()\),實際上是 \(\sigma_0(k)\) 級別的,然後跑得飛快(卡常是在輸入部分)


簡易證明:

分別考慮每一個質數 \(p\),設 \(a,b,k\) 分別包含 \(n_a,n_b,n_k\)\(p\)

則有 \(\min(n_a+n_b,n_k)=n_b\)

  • \(n_a=0\),則 \(n_b\le n_k\)
  • \(n_a\not=0\),則 \(n_b=n_k\)

所以 \(b|k\)


#include<bits/stdc++.h>
using namespace std;

namespace IO {
    #define getc() p1==p2&&(p2=(p1=buf)+fread(buf,1,1<<21,stdin),p1==p2)?EOF:*p1++
    //#define getc() getchar()
    char buf[1<<21],*p1,*p2,ch;
    void rd(int &x){
        x=0;char c=getc();
        while(c<48||c>57) c=getc();
        while(c>=48&&c<=57) x=x*10+c-48,c=getc();
    }
}
using IO::rd;

template<typename T>
inline void Read(T &n){
    char ch; bool flag=false;
    while(!isdigit(ch=getchar())) if(ch=='-')flag=true;
    for(n=ch^48; isdigit(ch=getchar()); n=(n<<1)+(n<<3)+(ch^48));
    if(flag) n=-n;
}

#define Read rd

enum{
    MAXN = 10000005,
    MOD = 998244353
};

inline int ksm(int base, int k=MOD-2){
    int res=1;
    while(k){
        if(k&1)
            res = 1ll*res*base%MOD;
        base = 1ll*base*base%MOD;
        k >>= 1;
    }
    return res;
}

inline void upd(int &a, long long b){a = (a+b)%MOD;}

int n, k;
int a[MAXN];
char vis[MAXN];

int cnt[MAXN], inv[MAXN];

int q[MAXN], val[MAXN], f[MAXN], jc[MAXN];

int gcd(int a, int b){return b?gcd(b,a%b):a;}
inline int calc(int len){
    int top=0, num = cnt[len];
    for(register int i=1; i<=k; i++) if(k%i==0 and gcd(len,k/i)==1) q[++top] = i, val[top] = ksm(len,i-1);
    f[0] = 1; int tp = 0; q[top+1] = 1e9;
    for(register int i=1; i<=num; i++){
        f[i] = 0;
        for(register int j=1; q[j]<=i; j++)
            f[i] = (f[i]+1ll*f[i-q[j]]*val[j])%MOD;
        f[i] = 1ll*f[i]*inv[i]%MOD;
    }
    return 1ll*f[num]*jc[num]%MOD;
}

int main(){
    // freopen("3.in","r",stdin);
    // freopen("3.out","w",stdout);
    Read(n); Read(k);
    for(register int i=1; i<=n; i++) Read(a[i]);
    jc[0] = 1; for(register int i=1; i<=n; i++) jc[i] = 1ll*jc[i-1]*i%MOD;
    inv[1] = 1; for(register int i=2; i<=n; i++) inv[i] = 1ll*(MOD-MOD/i)*inv[MOD%i]%MOD;
    for(register int i=1; i<=n; i++) if(!vis[i]){
        int len=1, x=i; vis[i] = 1;
        while(!vis[a[x]]) vis[x = a[x]] = true, len++;
        cnt[len]++;
    }
    int ans=1;
    for(register int i=1; i<=n; i++) if(cnt[i]) ans = 1ll*ans*calc(i)%MOD;
    cout<<ans<<'\n';
    return 0;
}