【[SDOI2017]序列計數】
阿新 • • 發佈:2019-01-01
感覺自己的複雜度感人
大概是\(O(p*\pi(m)+p^3logn)\)
還是能過去的
我們看到這麼大的資料範圍還是應該先想一想暴力怎麼寫
顯然我們可以直接暴力\(dp\)
設\(dp[i][j]\)表示已經選擇了\(i\)數,其中所有數的和\(mod\ p\)為\(j\)的方案數
顯然方程是
\[f[i][j]=\sum_{k=1}^mdp[i-1][((j-k)\%p+p)\%p]\]
初始的狀態是\(dp[0][0]=1\),最終的答案是\(dp[n][0]\)
至於還有一個至少有一個素數的限制條件,我們可以先不管這個條件直接算一遍,之後再保證\(k\)不為素數再算一遍,兩個一減就是答案了
這樣暴力轉移的複雜度是\(O(nmp)\)的,於是我們要考慮優化
這個轉移相當的固定,於是可以矩乘優化
我們發現因為\(p\)非常的小,於是那個膜\(p\)意義下的轉移會有很多重複的位置被轉移過去,於是我們如果可以預處理出這樣一個數組\(tot[j][k]\)
表示\(dp[i-1][k]\)會向\(dp[i][k]\)專一多少次,也就是\(dp[i][k]+=dp[i-1][j]*tot[j][k]\)
於是就有這樣一個矩陣會被構造出來
(\(p=3\)的情況)
於是就可以轉移了,至於\(tot[j][k]\)怎麼求,這個就是很簡單了
在沒有素數的情況下把所有素數對應的轉移減一遍就好了
程式碼
#include<iostream> #include<cstring> #include<bitset> #include<cstdio> #define re register #define maxn 20000005 #define LL long long const LL mod=20170408; std::bitset<maxn> f; int prime[1500000]; LL ans[101][101],a[101][101]; int m,p; LL n; inline void did_a() { LL mid[101][101]; for(re int i=1;i<=p;i++) for(re int j=1;j<=p;j++) mid[i][j]=a[i][j],a[i][j]=0; for(re int i=1;i<=p;i++) for(re int j=1;j<=p;j++) for(re int k=1;k<=p;k++) a[i][j]=(a[i][j]+(mid[i][k]*mid[k][j])%mod)%mod; } inline void did_ans() { LL mid[101][101]; for(re int i=1;i<=p;i++) for(re int j=1;j<=p;j++) mid[i][j]=ans[i][j],ans[i][j]=0; for(re int i=1;i<=p;i++) for(re int j=1;j<=p;j++) for(re int k=1;k<=p;k++) ans[i][j]=(ans[i][j]+(mid[i][k]*a[k][j])%mod)%mod; } inline void Rebuild() { memset(ans,0,sizeof(ans)); memset(a,0,sizeof(a)); for(re int i=1;i<=p;i++) ans[i][i]=1; int t=m/p; for(re int i=1;i<=p;i++) for(re int j=1;j<=p;j++) a[i][j]=t; int tot=m%p; for(re int i=1;i<=p;i++) { int cnt=tot,x=i-1; if(!x) x=p; while(cnt) { a[i][x]++,cnt--; x--; if(!x) x=p; } } } inline void out() { for(re int i=1;i<=p;i++) { for(re int j=1;j<=p;j++) printf("%d ",a[i][j]); putchar(10); } } inline void quick(LL b) { while(b) { if(b&1ll) did_ans(); b>>=1ll; did_a(); } } int main() { scanf("%lld%d%d",&n,&m,&p); f[1]=1; for(re int i=2;i<=m;i++) { if(!f[i]) prime[++prime[0]]=i; for(re int j=1;j<=prime[0]&&prime[j]*i<=m;j++) { f[prime[j]*i]=i; if(i%prime[j]==0) break; } } Rebuild(); quick(n); LL num=ans[1][1]; Rebuild(); for(re int i=1;i<=p;i++) { for(re int j=1;j<=prime[0];j++) { a[i][((i-1-prime[j])%p+p)%p+1]--; if(a[i][(i-prime[j]+p)%p+1]<0) a[i][(i-prime[j]+p)%p+1]=mod-1; } } quick(n); std::cout<<(num-ans[1][1]+mod)%mod; return 0; }