MTT學習小記
阿新 • • 發佈:2020-08-06
求p是1e9級別,n是1e5級別的fft
首先拆係數拆成aw+b的形式,那麼求的是(aw+b)(cw+d)=acw^2+(ad+bc)w+bd,變成求ac,ad,bc,bd的卷積
構造\(P=(a+bi)(c+di)=(ac-bd)+(ad+bc)i\),\(Q=(a-bi)(c+di)=(ac+bd)+(ad-bc)i\),求出PQ之後解方程可以解出來
觀察a+bi和a-bi是共軛的,根據共軛複數的性質(ab)'=a'b',求出a+bi的點值之後可以直接得到a-bi的點值
具體來說,\([x^i]DFT_{a+bi}(x)=[x^{N-i}]DFT_{a-bi}(x)\)(注意是N不是N-1,因為1和N-1共軛)
其實就是i和N-i的單位根也共軛
c+di直接求,再對PQ用兩次IDFT即可共四次DFT求出最終解
注意精度,所以單位根不能一個個乘過去
code
洛谷模板
#include <bits/stdc++.h> #define fo(a,b,c) for (a=b; a<=c; a++) #define fd(a,b,c) for (a=b; a>=c; a--) #define ll long long //#define file using namespace std; struct type{long double x,y;} p[262144],q[262144],F[262144]; type operator + (type a,type b) {return {a.x+b.x,a.y+b.y};} type operator - (type a,type b) {return {a.x-b.x,a.y-b.y};} type operator * (type a,type b) {return {a.x*b.x-a.y*b.y,a.x*b.y+a.y*b.x};} int N,len,n,m,mod,i,j,k,l,a[262144],b[262144],c[262144],d[262144]; ll ac[200001],ad[200001],bc[200001],bd[200001],ans[200001]; const ll w=100000; void dft(type *a,int tp) { static type A[262144]; int i,j,k,l,S=N,s1=2,s2=1; type u,v,w,W; fo(i,0,N-1) { j=i;k=0; fo(l,1,len) k=k*2+(j&1),j>>=1; A[i]=a[k]; } memcpy(a,A,sizeof(A)); fo(i,1,len) { S>>=1; fo(j,0,S-1) { fo(k,0,s2-1) { W={cos(2*M_PI*k/s1),sin(2*M_PI*k/s1)*tp}; u=a[j*s1+k],v=a[j*s1+k+s2]*W; a[j*s1+k]=u+v; a[j*s1+k+s2]=u-v; } } s1<<=1,s2<<=1; } } int main() { #ifdef file freopen("mtt.in","r",stdin); freopen("mtt.out","w",stdout); #endif scanf("%d%d%d",&n,&m,&mod);len=ceil(log2(n+m+1));N=pow(2,len); fo(i,0,n) scanf("%d",&j),a[i]=j/w,b[i]=j%w; fo(i,0,m) scanf("%d",&j),c[i]=j/w,d[i]=j%w; fo(i,0,n) p[i]={a[i],b[i]}; dft(p,1); fo(i,0,N-1) q[i]={p[(N-i)%N].x,-p[(N-i)%N].y}; fo(i,0,m) F[i]={c[i],d[i]}; dft(F,1); fo(i,0,N-1) p[i]=p[i]*F[i],q[i]=q[i]*F[i]; dft(p,-1),dft(q,-1); fo(i,0,N-1) p[i].x/=N,p[i].y/=N,q[i].x/=N,q[i].y/=N; fo(i,0,n+m) ac[i]=floor((p[i].x+q[i].x)/2+0.5),bd[i]=floor((q[i].x-p[i].x)/2+0.5),ad[i]=floor((p[i].y+q[i].y)/2+0.5),bc[i]=floor((p[i].y-q[i].y)/2+0.5); fo(i,0,n+m) ans[i]=(ac[i]%mod*w%mod*w+((ad[i]+bc[i])%mod)*w+bd[i])%mod; fo(i,0,n+m) printf("%lld ",(ans[i]+mod)%mod);printf("\n"); }