1. 程式人生 > 實用技巧 >聯賽模擬測試5 塗色遊戲 矩陣優化DP

聯賽模擬測試5 塗色遊戲 矩陣優化DP

題目描述

分析

定義出\(dp[i][j]\)為第\(i\)列塗\(j\)種顏色的方案數

然後我們要解決幾個問題

首先是求出某一列塗恰好\(i\)種顏色的方案數\(d[i]\)

如果沒有限制必須塗\(i\)種,而是有的顏色可以不塗,那麼方案數為\(i^n\)

為了避免少塗的情況,我們減去只塗\(1 \sim i-1\)種顏色的方案數

\(d[i]=i^n-\sum_{j=1}^{i-1}C_i^j \times d[j]\)

初始化為\(d[1]=1\)

接下來考慮轉移

\(f[i][j]=f[i-1][k] \times d[j] \times C_k^{cf} \times C_{p-k}^{j-cf}\)


其中\(i\)為當前列的編號,\(j\)為當前列選了幾種顏色,\(k\)為上一列選了幾種顏色,\(cf\)為這些顏色有幾種相同的
注意兩個組合數不能寫成\(C_j^{cf} \times C_{p-j}^{k-cf}\)

因為我們要選出\(j\)個,而不是\(k\)

時間複雜度\(m \times n^3\)

期望得分:\(40\),實際得分:\(50\)

下一步我們考慮怎麼優化

我們會發現,如果\(j\)\(k\)確定了,那麼\(f[i-1][k]\)乘的係數就確定了

根據乘法分配率,我們可以把係數預處理出來,優化掉一維

時間複雜度\(m \times n^2\)

期望得分:\(70\),實際得分:\(70\)

我們繼續觀察會發現,每一列的轉移乘的係數都是固定的

結合\(m\)的大小,我們可以使用矩陣快速冪優化

時間複雜度\(logm \times n^3\)
期望得分:\(100\),實際得分:\(70\)

因為出題人卡常卡到喪心病狂,最後幾個點仍然會跑到\(2s\)

所以我們要優化程式碼的常數
能不用\(longlong\)就不用\(longlong\)

減少取模的次數
加幾個玄學的\(register\)\(inline\)

再手動吸一下氧就可以了
時間複雜度\(logm \times n^3\)
期望得分:\(100\),實際得分:\(100\)

程式碼

#include<cstdio>
#include<algorithm>
#include<cmath>
#include<cstring>
#define fastcall __attribute__((optimize("-O3")))
%:pragma GCC optimize(2)
%:pragma GCC optimize(3)
%:pragma GCC optimize("Ofast")
%:pragma GCC optimize("inline")
%:pragma GCC optimize("-fgcse")
%:pragma GCC optimize("-fgcse-lm")
%:pragma GCC optimize("-fipa-sra")
%:pragma GCC optimize("-ftree-pre")
%:pragma GCC optimize("-ftree-vrp")
%:pragma GCC optimize("-fpeephole2")
%:pragma GCC optimize("-ffast-math")
%:pragma GCC optimize("-fsched-spec")
%:pragma GCC optimize("unroll-loops")
%:pragma GCC optimize("-falign-jumps")
%:pragma GCC optimize("-falign-loops")
%:pragma GCC optimize("-falign-labels")
%:pragma GCC optimize("-fdevirtualize")
%:pragma GCC optimize("-fcaller-saves")
%:pragma GCC optimize("-fcrossjumping")
%:pragma GCC optimize("-fthread-jumps")
%:pragma GCC optimize("-funroll-loops")
%:pragma GCC optimize("-freorder-blocks")
%:pragma GCC optimize("-fschedule-insns")
%:pragma GCC optimize("inline-functions")
%:pragma GCC optimize("-ftree-tail-merge")
%:pragma GCC optimize("-fschedule-insns2")
%:pragma GCC optimize("-fstrict-aliasing")
%:pragma GCC optimize("-falign-functions")
%:pragma GCC optimize("-fcse-follow-jumps")
%:pragma GCC optimize("-fsched-interblock")
%:pragma GCC optimize("-fpartial-inlining")
%:pragma GCC optimize("no-stack-protector")
%:pragma GCC optimize("-freorder-functions")
%:pragma GCC optimize("-findirect-inlining")
%:pragma GCC optimize("-fhoist-adjacent-loads")
%:pragma GCC optimize("-frerun-cse-after-loop")
%:pragma GCC optimize("inline-small-functions")
%:pragma GCC optimize("-finline-small-functions")
%:pragma GCC optimize("-ftree-switch-conversion")
%:pragma GCC optimize("-foptimize-sibling-calls")
%:pragma GCC optimize("-fexpensive-optimizations")
%:pragma GCC optimize("inline-functions-called-once")
%:pragma GCC optimize("-fdelete-null-pointer-checks")
const int maxn=1e4+5;
const int maxm=105;
const int maxp=1e4+5;
const int mod=998244353;
int ny[maxn],jc[maxn],jcc[maxn],f[maxp][maxm],n,m,p,q,d[maxm],xs[maxm][maxm];
int a[maxm][maxm];
int ans;
int getC(int nn,int mm){
	return 1LL*jc[nn]*jcc[mm]%mod*jcc[nn-mm]%mod;
}
int ksm(int ds,int zs){
	int ans=1;
	while(zs){
		if(zs&1) ans=1LL*ans*ds%mod;
		ds=1LL*ds*ds%mod;
		zs>>=1;
	}
	return ans;
}
struct asd{
	int sz[maxm][maxm];
	asd(){
		memset(sz,0,sizeof(sz));
	}
}da,xss;
#define reg register 
asd cf(asd aa,asd bb){
	asd cc;
	for(reg int i=1;i<maxm;i++){
		for(reg int j=1;j<maxm;j++){
			for(reg int k=1;k<maxm;k++){
				cc.sz[i][j]=(cc.sz[i][j]+1LL*aa.sz[i][k]*bb.sz[k][j]%mod);
				if(cc.sz[i][j]>=mod) cc.sz[i][j]-=mod;
			}
		}
	}
	return cc;
}
int main(){
	freopen("color.in","r",stdin);
	freopen("color.out","w",stdout);
	scanf("%d%d%d%d",&n,&m,&p,&q);
	ny[1]=1;
	for(int i=2;i<maxm;i++){
		ny[i]=1LL*(mod-mod/i)*ny[mod%i]%mod;
	}
	jc[0]=jcc[0]=1;
	for(int i=1;i<maxm;i++){
		jc[i]=1LL*jc[i-1]*i%mod;
		jcc[i]=1LL*jcc[i-1]*ny[i]%mod;
	}
	int mmax=std::min(n,p);
	d[1]=1;
	for(reg int i=2;i<=mmax;i++){
		d[i]=ksm(i,n);
		for(reg int j=1;j<i;j++){
			d[i]=(d[i]-1LL*d[j]*getC(i,j)%mod+mod);
			if(d[i]>=mod) d[i]-=mod;
		}
	}
	for(reg int i=1;i<=mmax;i++){
		f[1][i]=1LL*getC(p,i)*d[i]%mod;
		da.sz[i][1]=f[1][i];
	}
	for(reg int j=1;j<=mmax;j++){
		for(reg int k=1;k<=mmax;k++){
			int noww=std::min(j,k);
			for(reg int cf=0;cf<=noww;cf++){
				if(j+k-cf<q || j+k-cf>p) continue;
				xs[j][k]=(xs[j][k]+1LL*d[j]*getC(k,cf)%mod*getC(p-k,j-cf)%mod);
				if(xs[j][k]>=mod) xs[j][k]-=mod;
				xss.sz[j][k]=xs[j][k];
			}
		}
	}
	m--;
	while(m){
		if(m&1) da=cf(xss,da);
		m>>=1;
		xss=cf(xss,xss);
	}
	for(int i=1;i<=mmax;i++){
		ans+=da.sz[i][1];
		if(ans>=mod) ans-=mod;
	}
	printf("%d\n",ans);
	return 0;
}