聯賽模擬測試5 塗色遊戲 矩陣優化DP
題目描述
分析
定義出\(dp[i][j]\)為第\(i\)列塗\(j\)種顏色的方案數
然後我們要解決幾個問題
首先是求出某一列塗恰好\(i\)種顏色的方案數\(d[i]\)
如果沒有限制必須塗\(i\)種,而是有的顏色可以不塗,那麼方案數為\(i^n\)
為了避免少塗的情況,我們減去只塗\(1 \sim i-1\)種顏色的方案數
即\(d[i]=i^n-\sum_{j=1}^{i-1}C_i^j \times d[j]\)
初始化為\(d[1]=1\)
接下來考慮轉移
\(f[i][j]=f[i-1][k] \times d[j] \times C_k^{cf} \times C_{p-k}^{j-cf}\)
其中\(i\)為當前列的編號,\(j\)為當前列選了幾種顏色,\(k\)為上一列選了幾種顏色,\(cf\)為這些顏色有幾種相同的
注意兩個組合數不能寫成\(C_j^{cf} \times C_{p-j}^{k-cf}\)
因為我們要選出\(j\)個,而不是\(k\)個
時間複雜度\(m \times n^3\)
期望得分:\(40\),實際得分:\(50\)
下一步我們考慮怎麼優化
我們會發現,如果\(j\)和\(k\)確定了,那麼\(f[i-1][k]\)乘的係數就確定了
根據乘法分配率,我們可以把係數預處理出來,優化掉一維
時間複雜度\(m \times n^2\)
期望得分:\(70\),實際得分:\(70\)
我們繼續觀察會發現,每一列的轉移乘的係數都是固定的
結合\(m\)的大小,我們可以使用矩陣快速冪優化
時間複雜度\(logm \times n^3\)
期望得分:\(100\),實際得分:\(70\)
因為出題人卡常卡到喪心病狂,最後幾個點仍然會跑到\(2s\)多
所以我們要優化程式碼的常數
能不用\(longlong\)就不用\(longlong\)
減少取模的次數
加幾個玄學的\(register\)和\(inline\)
再手動吸一下氧就可以了
時間複雜度\(logm \times n^3\)
期望得分:\(100\),實際得分:\(100\)
程式碼
#include<cstdio> #include<algorithm> #include<cmath> #include<cstring> #define fastcall __attribute__((optimize("-O3"))) %:pragma GCC optimize(2) %:pragma GCC optimize(3) %:pragma GCC optimize("Ofast") %:pragma GCC optimize("inline") %:pragma GCC optimize("-fgcse") %:pragma GCC optimize("-fgcse-lm") %:pragma GCC optimize("-fipa-sra") %:pragma GCC optimize("-ftree-pre") %:pragma GCC optimize("-ftree-vrp") %:pragma GCC optimize("-fpeephole2") %:pragma GCC optimize("-ffast-math") %:pragma GCC optimize("-fsched-spec") %:pragma GCC optimize("unroll-loops") %:pragma GCC optimize("-falign-jumps") %:pragma GCC optimize("-falign-loops") %:pragma GCC optimize("-falign-labels") %:pragma GCC optimize("-fdevirtualize") %:pragma GCC optimize("-fcaller-saves") %:pragma GCC optimize("-fcrossjumping") %:pragma GCC optimize("-fthread-jumps") %:pragma GCC optimize("-funroll-loops") %:pragma GCC optimize("-freorder-blocks") %:pragma GCC optimize("-fschedule-insns") %:pragma GCC optimize("inline-functions") %:pragma GCC optimize("-ftree-tail-merge") %:pragma GCC optimize("-fschedule-insns2") %:pragma GCC optimize("-fstrict-aliasing") %:pragma GCC optimize("-falign-functions") %:pragma GCC optimize("-fcse-follow-jumps") %:pragma GCC optimize("-fsched-interblock") %:pragma GCC optimize("-fpartial-inlining") %:pragma GCC optimize("no-stack-protector") %:pragma GCC optimize("-freorder-functions") %:pragma GCC optimize("-findirect-inlining") %:pragma GCC optimize("-fhoist-adjacent-loads") %:pragma GCC optimize("-frerun-cse-after-loop") %:pragma GCC optimize("inline-small-functions") %:pragma GCC optimize("-finline-small-functions") %:pragma GCC optimize("-ftree-switch-conversion") %:pragma GCC optimize("-foptimize-sibling-calls") %:pragma GCC optimize("-fexpensive-optimizations") %:pragma GCC optimize("inline-functions-called-once") %:pragma GCC optimize("-fdelete-null-pointer-checks") const int maxn=1e4+5; const int maxm=105; const int maxp=1e4+5; const int mod=998244353; int ny[maxn],jc[maxn],jcc[maxn],f[maxp][maxm],n,m,p,q,d[maxm],xs[maxm][maxm]; int a[maxm][maxm]; int ans; int getC(int nn,int mm){ return 1LL*jc[nn]*jcc[mm]%mod*jcc[nn-mm]%mod; } int ksm(int ds,int zs){ int ans=1; while(zs){ if(zs&1) ans=1LL*ans*ds%mod; ds=1LL*ds*ds%mod; zs>>=1; } return ans; } struct asd{ int sz[maxm][maxm]; asd(){ memset(sz,0,sizeof(sz)); } }da,xss; #define reg register asd cf(asd aa,asd bb){ asd cc; for(reg int i=1;i<maxm;i++){ for(reg int j=1;j<maxm;j++){ for(reg int k=1;k<maxm;k++){ cc.sz[i][j]=(cc.sz[i][j]+1LL*aa.sz[i][k]*bb.sz[k][j]%mod); if(cc.sz[i][j]>=mod) cc.sz[i][j]-=mod; } } } return cc; } int main(){ freopen("color.in","r",stdin); freopen("color.out","w",stdout); scanf("%d%d%d%d",&n,&m,&p,&q); ny[1]=1; for(int i=2;i<maxm;i++){ ny[i]=1LL*(mod-mod/i)*ny[mod%i]%mod; } jc[0]=jcc[0]=1; for(int i=1;i<maxm;i++){ jc[i]=1LL*jc[i-1]*i%mod; jcc[i]=1LL*jcc[i-1]*ny[i]%mod; } int mmax=std::min(n,p); d[1]=1; for(reg int i=2;i<=mmax;i++){ d[i]=ksm(i,n); for(reg int j=1;j<i;j++){ d[i]=(d[i]-1LL*d[j]*getC(i,j)%mod+mod); if(d[i]>=mod) d[i]-=mod; } } for(reg int i=1;i<=mmax;i++){ f[1][i]=1LL*getC(p,i)*d[i]%mod; da.sz[i][1]=f[1][i]; } for(reg int j=1;j<=mmax;j++){ for(reg int k=1;k<=mmax;k++){ int noww=std::min(j,k); for(reg int cf=0;cf<=noww;cf++){ if(j+k-cf<q || j+k-cf>p) continue; xs[j][k]=(xs[j][k]+1LL*d[j]*getC(k,cf)%mod*getC(p-k,j-cf)%mod); if(xs[j][k]>=mod) xs[j][k]-=mod; xss.sz[j][k]=xs[j][k]; } } } m--; while(m){ if(m&1) da=cf(xss,da); m>>=1; xss=cf(xss,xss); } for(int i=1;i<=mmax;i++){ ans+=da.sz[i][1]; if(ans>=mod) ans-=mod; } printf("%d\n",ans); return 0; }