[翻譯向]階乘模大質數
本文大部分翻譯自http://min-25.hatenablog.com/entry/2017/04/10/215046,有改動。min_25牛逼
考慮經典問題:求$n!\bmod p$,p為一個大質數。
令$v=\lfloor \sqrt{n} \rfloor$,設$g_p(x)=\prod_{i=1}^p (x+i)$,那麽我們想要求$g_v(0),g_v(v)...g_v((v-1)v)$。
考慮倍增地求,假設我們有了$g_d(0),g_d(v)...g_d(dv)$,那麽$g_d$本身也唯一確定了,那麽如何求出$g_{2d}(0)g_{2d}(v)...g_{2d}(2dv)$呢?
註意到$g_{2d}(x)=g_d(x)g_d(x+d)$,那麽我們如果求出了$g_d((d+1)v),g_d((d+2)v)...g_d(2dv)$,以及$g_d(d),g_d(d+v)...g_d(2dv+d)$,那麽我們就可以直接算出$g_{2d}(0),g_{2d}(v)...g_{2d}(2dv)$。
考慮假設給定了一個多項式h的$h(0),h(1)...h(d)$,如何求$h(k),h(1+k)...h(d+k)$。
進行拉格朗日插值,那麽:
$\begin{aligned} h(m+k) &= \sum_{i=0}^{d} h(i) \prod _{j=0, i\ne j}^{d} \frac{m+k-j}{i-j} \\ &= \left(\prod_{j=0}^{d} (m+k-j)\right) \left( \sum_{i=0}^{d} \frac{h(i)}{i! (d-i)!(-1)^{d-i}} \cdot \frac{1}{m+k-i} \right) \end{aligned}$
前一個括號可以通過預處理$[k-d,k+d]$的階乘前綴積得到。後面一看就是個卷積。$O(d\log(d))$。
註意:如果$k,1+k...d+k$與$0,1,...,d$有公共部分,那麽分母會為0(雖然也可以處理,但是麻煩一點)。在本題中恰好不會。
接下來的事情就很簡單了,令$h(x)=g_d(vx)$即可。
如何求$g_{d+1}(0),g_{d+1}(v)...g_{d+1}((d+1)v)$呢?前面到dv為止暴力添上一項,最後一項暴力。
這個做法復雜度是$O(\sqrt{n}\log(n))$,可以用威爾遜定理減小常數。比起傳統的$O(\sqrt{n}\log^{1.5}(n)) \sim O(\sqrt{n}\log^2(n))$多點求值做法,這個做法常數小,還好寫,不知道高到哪裏去了。
測速:51nod 1387。如果n為偶數,答案為n!,否則為$\frac{n!}{2}$。這題有一個坑點在於fft長度只能到$2^{16}$,超出的需要拆成兩段(雖然好像不會變慢就是了)。
#define SZ 666666 int MOD; ll w[2][SZ],G,fac[SZ],rfac[SZ]; inline ll qp(ll a,ll b) { ll ans=1; a%=MOD; while(b) { if(b&1) ans=ans*a%MOD; a=a*a%MOD; b>>=1; } return ans; } inline ll org_root() { static ll yss[2333]; int yyn=0; ll xp=MOD-1; for(ll i=2;i*i<=xp;i++) { if(xp%i) continue; yss[++yyn]=i; while(xp%i==0) xp/=i; } if(xp!=1) yss[++yyn]=xp; ll ans=1; while(1) { bool ok=1; for(int i=1;i<=yyn;i++) if(qp(ans,(MOD-1)/yss[i])==1) {ok=0; break;} if(ok) return ans; ++ans; } } int K; ll rv; inline void fftinit(int n) { for(K=1;K<n;K<<=1); w[0][0]=w[0][K]=1; ll g=qp(G,(MOD-1)/K); for(int i=1;i<K;i++) w[0][i]=w[0][i-1]*g%MOD; for(int i=0;i<=K;i++) w[1][i]=w[0][K-i]; rv=qp(K,MOD-2); } inline void fft(int* x,int v) { for(int i=0;i<K;i++) (x[i]<0)?(x[i]+=MOD):0; for(int i=0,j=0;i<K;i++) { if(i>j) swap(x[i],x[j]); for(int l=K>>1;(j^=l)<l;l>>=1); } for(int i=2;i<=K;i<<=1) for(int l=0;l<i>>1;l++) { register int W=w[v][K/i*l],*p=x+l+(i>>1),*q=x+l,t; for(register int j=0;j<K;j+=i) { p[j]=(q[j]-(t=(ll)p[j]*W%MOD)<0)?(q[j]-t+MOD):(q[j]-t); q[j]=(q[j]>=MOD-t)?(q[j]-MOD+t):(q[j]+t); } } if(!v) return; for(int i=0;i<K;i++) x[i]=x[i]*rv%MOD; } ll ff[SZ]; int A[SZ],B[SZ],C[SZ]; inline void calc(int*h,int*o,int d,int k) { int off=k-d-1; ff[0]=1; for(int j=1;j<=d+d+3;++j) { int s=(j+off)%MOD; if(s<0) s+=MOD; ff[j]=ff[j-1]*(ll)s%MOD; } fftinit(d+d+d+5); memset(A,0,sizeof(A[0])*K); memset(B,0,sizeof(B[0])*K); for(int i=0;i<=d;++i) { A[i]=h[i]*(ll)rfac[i]%MOD*rfac[d-i]%MOD; if((d-i)&1) A[i]=(MOD-A[i])%MOD; } for(int i=0;i<=d+d;++i) B[i]=qp(i-d+k,MOD-2); if(K<=(1<<16)) { fft(A,0); fft(B,0); for(int i=0;i<K;++i) C[i]=(ll)A[i]*B[i]%MOD; fft(C,1); } else { fftinit(K>>1); fft(A,0); fft(A+K,0); fft(B,0); fft(B+K,0); for(int i=0;i<K;++i) C[i+K]=(A[i]*(ll)B[i+K]+(ll)B[i]*A[i+K])%MOD; for(int i=0;i<K;++i) C[i]=A[i]*(ll)B[i]%MOD; fft(C,1); fft(C+K,1); } for(int i=0;i<=d;++i) { //i+k-d...i+k o[i]=C[i+d]*ff[i+k-off]%MOD *qp(ff[i+k-d-off-1],MOD-2)%MOD; (o[i]<0)?(o[i]+=MOD):0; } } int V; ll rV; int aa[SZ],bb[SZ]; inline void work(int x,vector<int>&v) { if(x==0) {v.pb(1); return;} if(x&1) { work(x-1,v); for(int i=0;i<v.size();++i) v[i]=(ll)v[i]*(i*V+x)%MOD; ll p=1; for(int i=1;i<=x;++i) p=p*(V*x+i)%MOD; v.pb(p); return; } int d=x>>1; work(d,v); for(int i=0;i<=d;++i) aa[i]=v[i]; calc(aa,aa+d+1,d,d+1); calc(aa,bb,d+d,d*rV%MOD); v.resize(x+1); for(int i=0;i<=x;++i) v[i]=aa[i]*(ll)bb[i]%MOD; } inline ll gfac_(int x) { V=sqrt(x); rV=qp(V,MOD-2); vector<int> tmp; work(V,tmp); ll ans=1; for(int i=0;i<V;++i) ans=ans*tmp[i]%MOD; for(int i=V*V+1;i<=x;++i) ans=ans*i%MOD; return ans; } inline ll gfac(int x) { if(x>=MOD) return 0; if(x>MOD-x-1) { int s=qp(gfac(MOD-x-1),MOD-2); if(x%2);else s=-s; return s; } return gfac_(x); } int main() { int n; scanf("%d%d",&n,&MOD); fac[0]=1; for(int i=1;i<SZ;++i) fac[i]=fac[i-1]*i%MOD; rfac[SZ-1]=qp(fac[SZ-1],MOD-2); for(int i=SZ-1;i>=1;--i) rfac[i-1]=rfac[i]*i%MOD; G=org_root(); int ans=gfac(n)%MOD; if(n&1) ans=ans*(ll)((MOD+1)/2)%MOD; ans%=MOD; if(ans<0) ans+=MOD; printf("%d\n",int(ans)); }
[翻譯向]階乘模大質數