1. 程式人生 > >Codeforces 1085G(1086E) Beautiful Matrix $dp$+樹狀陣列

Codeforces 1085G(1086E) Beautiful Matrix $dp$+樹狀陣列

題意

定義一個\(n*n\)的矩陣是\(beautiful\)的,需要滿足以下三個條件:

1.每一行是一個排列。

2.上下相鄰的兩個元素的值不同。

再定義兩個矩陣的字典序大的矩陣大。

給出一個\(beautiful\)\(n*n\)的矩陣,求有多少個矩陣小於這個矩陣。

Solution

逐行計算。

\(ans=\)每行字典序比這行小的排列且與上一行相鄰的兩個元素值不同的排列個數*\(n\)個元素錯排的方案數\(^{n-i}\)

第一行的方案數隨便算,我就不說了。

另外的行大概就是逐位算。

從後往前列舉前\(i\)個數相同,樹狀陣列維護當前位置可以填的數有幾個有限制(即上一行後\(n-i+1\)

中有這個數)和當前能填哪些數(即比\(a_{i,j}\)小且在當前行後\(n-i+1\)個數中出現了),不難發現有限制的數或者沒限制的數都是同質的,那麼就可以乘法原理算了,問題就是有幾個數有限制的錯排怎麼算方案數?\(dp\)一下就好了。

\(dp_{i,j}\)表示\(i\)個數中有\(j\)個數有限制的排列的方案數。

考慮從\(dp_{i,j-1}\)轉移,減去多了一個限制的數會少的方案數。

多了一個限制的數不合法的方案數?那我們就強制多的那個數不符合限制,另外數符合限制,也就是\(dp_{i-1,j-1}\)

\(dp_{i,j}=dp_{i,j-1}-dp_{i-1,j-1}\)

如果不會推,也可以打表

\(dp_{n,n}\)的值就是\(n\)個數錯排的方案數。

#include<bits/stdc++.h>
#define For(i,x,y) for (register int i=(x);i<=(y);i++)
#define Dow(i,x,y) for (register int i=(x);i>=(y);i--)
#define cross(i,k) for (register int i=first[k];i;i=last[i])
#define Debug(x) cerr<<#x<<"="<<(x)<<endl
#define mp make_pair
#define fi first
#define se second
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
typedef pair<int,int> pa;
inline ll read(){
    ll x=0;int ch=getchar(),f=1;
    while (!isdigit(ch)&&(ch!='-')&&(ch!=EOF)) ch=getchar();
    if (ch=='-'){f=-1;ch=getchar();}
    while (isdigit(ch)){x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}
    return x*f;
}
const int N = 2010;
int n,a[N][N];
const int mod = 998244353;
int fac[N],dp[N][N],p[N];
inline void init(){
    fac[0]=1;For(i,1,n) fac[i]=1ll*fac[i-1]*i%mod;
    dp[1][0]=1;
    For(i,2,n){
        dp[i][0]=fac[i];
        For(j,1,i) dp[i][j]=(dp[i][j-1]-dp[i-1][j-1]+mod)%mod;
    }
    p[0]=1;For(i,1,n) p[i]=1ll*p[i-1]*dp[n][n]%mod;
}
struct BIT{
    int c[N],sum;
    inline void clear(){sum=0,memset(c,0,sizeof c);}
    inline void Add(int x){sum++;for (;x<=n;x+=x&-x) c[x]++;}
    inline int Query(int x){int ans=0;for (;x;x-=x&-x) ans+=c[x];return ans;}
}t,T;
int b[N],ans;
inline void Add(int x){if (++b[x]==2) T.Add(x);}
inline void upd(int &x,int y){x+=y,(x>=mod)?x-=mod:0;}
int main(){
    n=read();
    For(i,1,n) For(j,1,n) a[i][j]=read();
    init();int sum=0;
    For(i,1,n) upd(sum,1ll*fac[n-i]*(a[1][i]-1-t.Query(a[1][i]-1))%mod),t.Add(a[1][i]);
    ans=1ll*sum*p[n-1]%mod;//printf("%d\n",ans);
    For(i,2,n){
        t.clear(),T.clear(),sum=0,memset(b,0,sizeof b);
        Dow(j,n,1){
            Add(a[i][j]),Add(a[i-1][j]),t.Add(a[i][j]);
            int x=T.Query(a[i][j]-1),y=t.Query(a[i][j]-1)-x,z=T.sum;
            if (b[a[i-1][j]]==2&&a[i-1][j]<a[i][j]) x--;
            if (b[a[i-1][j]]==2) z--;
            upd(sum,1ll*x*dp[n-j][z-1]%mod),upd(sum,1ll*y*dp[n-j][z]%mod);
            //printf("%d %d  ",z,x*dp[n-j][z-1]);
        }//puts("");
        upd(ans,1ll*sum*p[n-i]%mod);
    }
    printf("%d\n",ans);
}