[UOJ86]mx的組合數——NTT+數位DP+原根與指標+盧卡斯定理
阿新 • • 發佈:2019-03-02
set printf cstring false 寫法 fin ack iostream ora
對於$100\%$的數據,我們考慮優化上述$DP$,我們拿其中第一個轉移方程來說(後兩個同理),我們設$h[k]=\sum\limits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$。可以發現轉移可以看成是$G[j*k\ mod\ p]=\sum\limits_{j=0}^{p-1}g[j]\sum\limits_{k=0}^{p-1}h[k]$,這和卷積式子很像,但他是乘法卷積,我們想辦法將它變成加法卷積:因為$p$是質數,那麽$p$一定有原根(設為$g$),也就是說對於任意$j$,其中$1\le j\le p-1$都有指標。我們設它的指標為$ind(j)$,那麽$j*k\ mod\ p$就能轉化為$g^{(ind(j)+ind(k))\ mod\ (p-1)}\ mod\ p$。這樣我們就能用$FFT$或$NTT$來加速$DP$了,但註意到$0$沒有指標,我們在轉移時先忽略$0$,在最後輸出答案時用總個數減掉其他答案就是$\%p=0$的個數了。註意原根從$1$開始枚舉。至於$10^{30}$可以用$\_\_int128$存。時間復雜度為$O(plog_{p}^2)$。
題目鏈接:
[UOJ86]mx的組合數
題目大意:給出四個數$p,n,l,r$,對於$\forall 0\le a\le p-1$,求$l\le x\le r,C_{x}^{n}\%p=a$的$x$的數量。$p<=3000$且保證$p$是質數,$n,l,r<=10^30$。
對於$10\%$的數據,可以直接楊輝三角推。
對於$20\%$的數據,因為$n$是確定的,可以遞推出$C_{x+1}^{n}=C_{x}^{n}*\frac{x+1}{x+1-n}$。
對於另外$20\%$的數據,可以枚舉$x$然後用$lucas$定理求。
對於另外$30\%$的數據,可以想到將問題轉化成小於等於$r$的個數$-$小於等於$l-1$的個數。由$lucas$定理可知,$C_{x}^{n}\ mod\ p=\prod C_{b_{i}}^{a_{i}}\ mod\ p$,其中$a_{i},b_{i}$分別為$n,x$在$p$進制下的第$i$位。那麽我們就可以用數位$DP$求,$f[i][j]$代表從最低為開始的前$i$位,每一位的值都不大於$b_{i}$且$\%p=j$的方案數;$g[i][j]$代表從最低為開始的前$i$位,每一位的值任意且$\%p=j$的方案數。設枚舉第$i+1$位為$x$,$C_{x}^{a_{i+1}}=k$。那麽可以得到$DP$轉移方程$g[i+1][jk\ mod\ p]+=g[i][j]$,若$x<b_{i+1}$,則$f[i+1][jk\ mod\ p]+=g[i][j]$,若$x=b_{i+1}$,則$f[i+1][jk\ mod\ p]+=f[i][j]$。時間復雜度為$O(p^2log_{p})$。
對於$100\%$的數據,我們考慮優化上述$DP$,我們拿其中第一個轉移方程來說(後兩個同理),我們設$h[k]=\sum\limits_{x=0}^{p-1}[C_{x}^{a_{i+1}}==k]$。可以發現轉移可以看成是$G[j*k\ mod\ p]=\sum\limits_{j=0}^{p-1}g[j]\sum\limits_{k=0}^{p-1}h[k]$,這和卷積式子很像,但他是乘法卷積,我們想辦法將它變成加法卷積:因為$p$是質數,那麽$p$一定有原根(設為$g$),也就是說對於任意$j$,其中$1\le j\le p-1$都有指標。我們設它的指標為$ind(j)$,那麽$j*k\ mod\ p$就能轉化為$g^{(ind(j)+ind(k))\ mod\ (p-1)}\ mod\ p$。這樣我們就能用$FFT$或$NTT$來加速$DP$了,但註意到$0$沒有指標,我們在轉移時先忽略$0$,在最後輸出答案時用總個數減掉其他答案就是$\%p=0$的個數了。註意原根從$1$開始枚舉。至於$10^{30}$可以用$\_\_int128$存。時間復雜度為$O(plog_{p}^2)$。
兩種寫法,讀者自選。
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long typedef __int128 int128; #define MOD 998244353 using namespace std; int p; int128 l,r,n; int pr[10]; int cnt; int G; int mx; ll sum; int ind[30010]; ll f[100000]; ll g[100000]; ll h[100000]; int a[200]; int b[200]; ll ans[30010]; int c[200][30010]; int mask=1; ll s[100000]; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int read_() { int x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } int128 read() { int128 x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } ll quick(int x,int y,int mod) { ll res=1ll; while(y) { if(y&1) { res=res*x%mod; } y>>=1; x=1ll*x*x%mod; } return res; } void NTT(ll *a,int len,int miku) { for(int k=0,i=0;i<len;i++) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int k=2;k<=len;k<<=1) { int t=k>>1; int x=quick(3,(MOD-1)/k,MOD); if(miku==-1) { x=quick(x,MOD-2,MOD); } for(int i=0;i<len;i+=k) { ll w=1; for(int j=i;j<i+t;j++) { ll tmp=a[j+t]*w%MOD; a[j+t]=(a[j]-tmp+MOD)%MOD; a[j]=(a[j]+tmp)%MOD; w=w*x%MOD; } } } if(miku==-1) { for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++) { a[i]=a[i]*t%MOD; } } } void solve(int128 num) { memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(h,0,sizeof(h)); memset(a,0,sizeof(a)); int res=0; for(int i=1;num;i++) { a[i]=num%p; num/=p; res=max(res,i); } mx=max(res,mx); g[0]=f[0]=1ll; for(int k=1;k<=mx;k++) { memset(h,0,sizeof(h)); memset(s,0,sizeof(s)); NTT(g,mask,1); NTT(f,mask,1); if(a[k]>=b[k]) { h[ind[c[k][a[k]]]]++; NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]+=1ll*h[i]*f[i]%MOD; s[i]%=MOD; } NTT(h,mask,-1); h[ind[c[k][a[k]]]]--; } for(int i=b[k];i<a[k];i++) { h[ind[c[k][i]]]++; } NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]+=1ll*h[i]*g[i]%MOD; s[i]%=MOD; } NTT(h,mask,-1); NTT(s,mask,-1); memset(f,0,sizeof(f)); for(int i=0;i<mask;i++) { f[i%(p-1)]+=s[i]; f[i%(p-1)]%=MOD; } for(int i=max(b[k],a[k]);i<p;i++) { h[ind[c[k][i]]]++; } NTT(h,mask,1); for(int i=0;i<mask;i++) { s[i]=1ll*h[i]*g[i]%MOD; } NTT(s,mask,-1); memset(g,0,sizeof(g)); for(int i=0;i<mask;i++) { g[i%(p-1)]+=s[i]; g[i%(p-1)]%=MOD; } } } int main() { p=read_(),n=read(),l=read(),r=read(); l--; int s=p-1; while(mask<(p<<1)) { mask<<=1; } for(int i=2;i*i<=s;i++) { if(s%i==0) { pr[++cnt]=i; while(s%i==0) { s/=i; } } } if(s!=1) { pr[++cnt]=s; } for(int i=1;i<p;i++) { bool flag=true; for(int j=1;j<=cnt;j++) { if(quick(i,(p-1)/pr[j],p)==1) { flag=false; break; } } if(flag) { G=i; break; } } sum=1ll; for(int i=0;i<p-1;i++) { ind[sum]=i; sum*=G,sum%=p; } int128 N=n; for(int i=1;N;i++) { b[i]=N%p; N/=p; mx=max(mx,i); } for(int i=1;i<=mx;i++) { for(int j=0;j<b[i];j++) { c[i][j]=0; } sum=1ll; for(int j=b[i];j<p;j++) { c[i][j]=sum; sum*=(j+1),sum%=p; sum*=quick(j+1-b[i],p-2,p),sum%=p; } } solve(l); for(int i=0;i<p-1;i++) { ans[quick(G,i,p)]-=f[i]; } for(int i=1;i<=p-1;i++) { ans[i]=(ans[i]%MOD+MOD)%MOD; } solve(r); for(int i=0;i<p-1;i++) { ans[quick(G,i,p)]+=f[i]; } for(int i=1;i<=p-1;i++) { ans[i]%=MOD; } ans[0]=(r-l)%MOD; for(int i=1;i<p;i++) { ans[0]-=ans[i]; ans[0]=(ans[0]%MOD+MOD)%MOD; } for(int i=0;i<p;i++) { printf("%lld\n",ans[i]); } }
#include<set> #include<map> #include<queue> #include<stack> #include<cmath> #include<cstdio> #include<vector> #include<bitset> #include<cstring> #include<iostream> #include<algorithm> #define ll long long typedef __int128 int128; #define MOD 998244353 using namespace std; int p; int128 l,r,n; int pr[10]; int cnt; int G; int mx; ll sum; int ind[30010]; ll f[100000]; ll g[100000]; ll A[100000]; ll B[100000]; ll C[100000]; int a[200]; int b[200]; ll ans[30010]; int c[200][30010]; int mask=1; int s[100000]; int pw[300010]; int fac[300010]; int inv[300010]; char *p1,*p2,buf[100000]; #define nc() (p1==p2&&(p2=(p1=buf)+fread(buf,1,100000,stdin),p1==p2)?EOF:*p1++) int read_() { int x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } int128 read() { int128 x=0; char c=nc(); while(c<48) { c=nc(); } while(c>47) { x=(((x<<2)+x)<<1)+(c^48),c=nc(); } return x; } ll quick(int x,int y,int mod) { ll res=1ll; while(y) { if(y&1) { res=res*x%mod; } y>>=1; x=1ll*x*x%mod; } return res; } void NTT(ll *a,int len,int miku) { for(int k=0,i=0;i<len;i++) { if(i>k) { swap(a[i],a[k]); } for(int j=len>>1;(k^=j)<j;j>>=1); } for(int k=2;k<=len;k<<=1) { int t=k>>1; int x=quick(3,(MOD-1)/k,MOD); if(miku==-1) { x=quick(x,MOD-2,MOD); } for(int i=0;i<len;i+=k) { ll w=1; for(int j=i;j<i+t;j++) { ll tmp=a[j+t]*w%MOD; a[j+t]=(a[j]-tmp+MOD)%MOD; a[j]=(a[j]+tmp)%MOD; w=w*x%MOD; } } } if(miku==-1) { for(int i=0,t=quick(len,MOD-2,MOD);i<len;i++) { a[i]=a[i]*t%MOD; } } } void solve(int128 num) { memset(f,0,sizeof(f)); memset(g,0,sizeof(g)); memset(a,0,sizeof(a)); int res=0; for(int i=1;num;i++) { a[i]=num%p; num/=p; res=max(res,i); } mx=max(res,mx); g[1]=f[1]=1ll; for(int k=1;k<=mx;k++) { memset(A,0,sizeof(A)); memset(B,0,sizeof(B)); for(int i=b[k];i<p;i++) { if(c[k][i]) { A[ind[c[k][i]]]++; } } for(int i=1;i<p;i++) { B[ind[i]]+=g[i]; B[ind[i]]%=MOD; } NTT(A,mask,1); NTT(B,mask,1); for(int i=0;i<mask;i++) { C[i]=A[i]*B[i]%MOD; } NTT(C,mask,-1); memset(g,0,sizeof(g)); for(int i=0;i<mask;i++) { (g[quick(G,i%(p-1),p)]+=C[i])%=MOD; } memset(A,0,sizeof(A)); for(int i=b[k];i<a[k];i++) { if(c[k][i]) { A[ind[c[k][i]]]++; } } NTT(A,mask,1); for(int i=0;i<mask;i++) { C[i]=A[i]*B[i]%MOD; } NTT(C,mask,-1); memset(s,0,sizeof(s)); for(int i=0;i<mask;i++) { (s[quick(G,i%(p-1),p)]+=C[i])%=MOD; } if(c[k][a[k]]) { for(int i=1;i<p;i++) { (s[c[k][a[k]]*i%p]+=f[i])%=MOD;; } } for(int i=1;i<p;i++) { f[i]=s[i]; } } } int get_ori(int p) { int s=p-1; for(int i=2;i*i<=s;i++) { if(s%i==0) { pr[++cnt]=i; while(s%i==0) { s/=i; } } } if(s!=1) { pr[++cnt]=s; } for(int i=1;i<p;i++) { bool flag=true; for(int j=1;j<=cnt;j++) { if(quick(i,(p-1)/pr[j],p)==1) { flag=false; break; } } if(flag) { return i; break; } } } int main() { p=read_(),n=read(),l=read(),r=read(); while(mask<(p<<1)) { mask<<=1; } G=get_ori(p); pw[0]=1ll; for(int i=1;i<p;i++) { pw[i]=pw[i-1]*G%p; } sum=1ll; for(int i=0;i<p-1;i++) { ind[sum]=i; sum*=G,sum%=p; } int128 N=n; for(int i=1;N;i++) { b[i]=N%p; N/=p; mx=max(mx,i); } fac[0]=inv[0]=1ll; for(int i=1;i<p;i++) { fac[i]=fac[i-1]*i%p; } inv[p-1]=quick(fac[p-1],p-2,p); for(int i=p-2;i>=1;i--) { inv[i]=inv[i+1]*(i+1)%p; } for(int i=1;i<=120;i++) { for(int j=b[i];j<p;j++) { c[i][j]=fac[j]*inv[j-b[i]]%p*inv[b[i]]%p; } } solve(r); for(int i=1;i<p;i++) { ans[i]=f[i]; } solve(l-1); for(int i=1;i<p;i++) { ans[i]=((ans[i]-f[i])%MOD+MOD)%MOD; } ans[0]=(r-l+1)%MOD; for(int i=1;i<p;i++) { ans[0]=((ans[0]-ans[i])%MOD+MOD)%MOD; } for(int i=0;i<p;i++) { printf("%lld\n",ans[i]); } }
[UOJ86]mx的組合數——NTT+數位DP+原根與指標+盧卡斯定理