多項式模板QAQ
阿新 • • 發佈:2018-11-15
失蹤人口迴歸?
似乎抄了挺多天的多項式題目程式碼。。留幾個模板吧。。
(都不知道是不是好模板。。
反正也是自己看
最長不超過80行我去
UOJ #34. 多項式乘法
#include <bits/stdc++.h>
#define me(a,x) memset(a,x,sizeof a)
using namespace std;
const int N=3e5+2,inf=1e9+7;
const double pi=acos(-1);
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
struct P{
double x,y;
P(){x=y=0;}
P(double a,double b){x=a,y=b;}
}a[N],b[N],c[N];
P operator+(P x ,P y){return P(x.x+y.x,x.y+y.y);}
P operator-(P x,P y){return P(x.x-y.x,x.y-y.y);}
P operator*(P x,P y){return P(x.x*y.x-x.y*y.y,x.y*y.x+x.x*y.y);}
int id[N],an,bn,cn,n,ln;
void fft(P *s,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
P wn(cos (pi/i),si*sin(pi/i));
for(int j=0;j<n;j+=i<<1){
P e(1,0),*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=e*wn){
P x=b[k],y=c[k]*e;
b[k]=x+y,c[k]=x-y;
}
}
}
if(si<0) for(int i=0;i<n;++i) s[i].x/=n;
}
int main(){
an=read()+1,bn=read()+1;
for(int i=0;i<an;++i) a[i].x=read();
for(int i=0;i<bn;++i) b[i].x=read();
n=1,ln=0; while(n<an+bn) n<<=1,++ln;
for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<(ln-1));
fft(a,1); fft(b,1);
for(int i=0;i<n;++i) c[i]=a[i]*b[i];
fft(c,-1);
printf("%d",int(c[0].x+0.5));
for(int i=1;i<an+bn-1;++i) printf(" %d",int(c[i].x+0.5));
puts("");
return 0;
}
BZOJ 3992 原根的應用(還是生成函式?)+NTT
#include<bits/stdc++.h>
using namespace std;
const int N=16385,Mod=1004535809;
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
int pw(int x,int k,int mod){
int r=1;
for(;k;k>>=1,x=1ll*x*x%mod) if(k&1)r=1ll*r*x%mod;
return r;
}
int id[N],an,n,m,ln,a[N],b[N],ans[N],c[N/2],v[N/2],ti,ny;
void ntt(int *s,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
int wn=pw(3,si==1?(Mod-1)/i/2:Mod-1-(Mod-1)/i/2,Mod);
for(int j=0;j<n;j+=i<<1){
int e=1,*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=1ll*e*wn%Mod){
int x=b[k],y=1ll*c[k]*e%Mod;
b[k]=(x+y)%Mod,c[k]=(x-y)%Mod;
}
}
}
if(si<0) for(int i=0;i<n;++i) s[i]=1ll*s[i]*ny%Mod;
}
void mul(int *a,int *bb){
for(int i=0;i<n;++i) b[i]=bb[i];
ntt(a,1),ntt(b,1);
for(int i=0;i<n;++i) a[i]=1ll*a[i]*b[i]%Mod;
ntt(a,-1);
for(int i=m-1;i<n;++i) a[i-m+1]=(a[i-m+1]+a[i])%Mod,a[i]=0;
}
bool check(int x,int m){
int u=1; ++ti;
for(int i=1;i<m;++i,u=u*x%m){
if(v[u]==ti)return 0; v[u]=ti;
}
return 1;
}
int get(int m){for(int i=2;i<=m;++i) if(check(i,m)) return i;}
int main(){
an=read(),m=read(); int x=read(),s=read(),g=get(m);
for(int i=1,w=g;i<m-1;++i,w=w*g%m) c[w]=i;
for(int i=1;i<=s;++i){
int x=read(); if(!x)continue;
a[c[x]]=1;
}
n=1,ln=0; while(n<m+m) n<<=1,++ln;
for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
ny=pw(n,Mod-2,Mod); ans[0]=1;
for(;an;an>>=1){
if(an&1) mul(ans,a);
mul(a,a);
}
printf("%d\n",(ans[c[x]]+Mod)%Mod);
return 0;
}
BZOJ3456 多項式求逆
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=262145,mo=1004535809;
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
int pw(int x,int k){
int r=1;
for(;k;k>>=1,x=1ll*x*x%mo) if(k&1)r=1ll*r*x%mo;
return r;
}
int id[N],an,ln,c[N],g[N],f[N],t[N],jc[N],ny[N];
void ntt(int *s,int n,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
int wn=pw(3,si==1?(mo-1)/i/2:mo-1-(mo-1)/i/2);
for(int j=0;j<n;j+=i<<1){
int e=1,*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=1ll*e*wn%mo){
int x=b[k],y=1ll*c[k]*e%mo;
b[k]=(x+y)%mo,c[k]=(x-y+mo)%mo;
}
}
}
int ny=pw(n,mo-2);
if(si<0) for(int i=0;i<n;++i) s[i]=1ll*s[i]*ny%mo;
}
void pre(const int n,const int ln){
for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
}
void get_inv(int *a,int *b,const int u,const int ln){
if(u==1){
b[0]=pw(a[0],mo-2); return;
}
get_inv(a,b,u>>1,ln-1);
pre(u<<1,ln+1);
for(int i=0;i<u;++i) t[i]=a[i],t[i+u]=0;
ntt(t,u<<1,1),ntt(b,u<<1,1);
for(int i=0;i<(u<<1);++i) t[i]=(2ll-(LL)t[i]*b[i]%mo)*b[i]%mo;
ntt(t,u<<1,-1);
for(int i=0;i<u;++i) b[i]=t[i],b[i+u]=0;
}
int main(){
an=read(); jc[0]=ny[0]=1; int n,i;
for(i=1;i<=an;++i) jc[i]=1ll*jc[i-1]*i%mo,ny[i]=pw(jc[i],mo-2);
for(i=0;i<=an;++i) g[i]=1ll*pw(2,1ll*i*(i-1)/2%(mo-1))*ny[i]%mo;
for(i=1;i<=an;++i) c[i]=1ll*pw(2,1ll*i*(i-1)/2%(mo-1))*ny[i-1]%mo;
n=1,ln=0; while(n<=an) n<<=1,++ln;
get_inv(g,f,n,ln);
ntt(c,n<<1,1); ntt(f,n<<1,1);
for(i=0;i<(n<<1);++i) f[i]=1ll*f[i]*c[i]%mo;
ntt(f,n<<1,-1);
printf("%d\n",(1ll*f[an]*jc[an-1]%mo+mo)%mo);
return 0;
}
51nod 1348 CRT+NTT(還有分治?
#include<bits/stdc++.h>
using namespace std;
typedef long long LL;
const int N=131073,mo1=998244353,mo2=1004535809,mo=100003;
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
int A[N],B[N],a[18][N],b1[18][N],b2[18][N];
int id[N],an,o[N],mod;
int pw(int x,int k,int p){
int r=1;
for(;k;k>>=1,x=1ll*x*x%p) if(k&1) r=1ll*r*x%p;
return r;
}
void ntt(int *s,int n,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
int wn=pw(3,si==1?(mod-1)/2/i:mod-1-(mod-1)/2/i,mod);
for(int j=0;j<n;j+=i<<1){
int e=1,*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=1ll*e*wn%mod){
int x=b[k],y=1ll*c[k]*e%mod;
b[k]=(x+y)%mod,c[k]=(x-y)%mod;
}
}
}
if(si<0) for(int ny=pw(n,mod-2,mod),i=0;i<n;++i) s[i]=1ll*s[i]*ny%mod;
}
LL mul(LL x,LL y,LL m){
LL tmp=(x*y-(LL)((double)x*y/m+1e-8)*m)%m;
return tmp<0?tmp+m:tmp;
}
LL merge(int m1,int m2){
LL m=1ll*mo1*mo2;
return ( mul((LL)mo2*pw(mo2,mo1-2,mo1),m1,m) + mul((LL)mo1*pw(mo1,mo2-2,mo2),m2,m) )%m;
}
void solve(int l,int r,int d){
if(l==r){a[d][0]=1,a[d][1]=o[l]; return;}
int mid=l+r>>1,m=r-l+1,ln=0,n=1;
while(n<=m+m) n<<=1,++ln;
solve(l,mid,d+1);
for(int i=0;i<=mid-l+1;++i) b1[d][i]=a[d+1][i],a[d+1][i]=0;
solve(mid+1,r,d+1);
for(int i=0;i<=r-mid;++i) b2[d][i]=a[d+1][i],a[d+1][i]=0;
for(int i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
mod=mo1;
for(int i=0;i<=mid-l+1;++i) A[i]=b1[d][i];
for(int i=0;i<=r-mid;++i) B[i]=b2[d][i];
ntt(A,n,1); ntt(B,n,1);
for(int i=0;i<n;++i) A[i]=1ll*A[i]*B[i]%mod,B[i]=0;
ntt(A,n,-1);
for(int i=0;i<=m;++i) a[d][i]=A[i],A[i]=0;
mod=mo2;
for(int i=0;i<=mid-l+1;++i) A[i]=b1[d][i];
for(int i=0;i<=r-mid;++i) B[i]=b2[d][i];
ntt(A,n,1); ntt(B,n,1);
for(int i=0;i<n;++i) A[i]=1ll*A[i]*B[i]%mod,B[i]=0;
ntt(A,n,-1);
for(int i=0;i<=m;++i) a[d][i]=merge(a[d][i],A[i])%mo,A[i]=0;
}
int main(){
an=read(); int q=read();
for(int i=1;i<=an;++i) o[i]=read()%mo;
solve(1,an,0);
while(q--) printf("%d\n",(a[0][read()]+mo)%mo);
return 0;
}
51nod 1172 任意模數fft 打了mtt
#include<bits/stdc++.h>
using namespace std;
typedef long double ld;
typedef long long LL;
const int N=131073,mo=1e9+7;
const ld pi=acos(-1);
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
struct P{
ld x,y;
P(ld a=0,ld b=0){x=a,y=b;}
inline P con(){return P(x,-y);}
}A[N],B[N],dfa[N],dfb[N],dfc[N],dfd[N];
P operator+(P x,P y){return P(x.x+y.x,x.y+y.y);}
P operator-(P x,P y){return P(x.x-y.x,x.y-y.y);}
P operator*(P x,P y){return P(x.x*y.x-x.y*y.y,x.y*y.x+x.x*y.y);}
int id[N],an,n,ln,ny[N>>1],jc[N>>1],a[N],c[N];
inline void fft(P *s,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
P wn=P(cos(pi/i),si*sin(pi/i));
for(int j=0;j<n;j+=i<<1){
P e=P(1,0),*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=e*wn){
P x=b[k],y=c[k]*e;
b[k]=x+y,c[k]=x-y;
}
}
}
//if(si<0) for(int i=0;i<n;++i) s[i].x/=n;
}
inline void mul(int *x,int *y){
for(int i=0;i<an;++i)
A[i]=P(x[i]&32767,x[i]>>15),B[i]=P(y[i]&32767,y[i]>>15);
fft(A,1); fft(B,1);
for(int i=0;i<n;++i){
int j=n-i & n-1;
P p=(A[i]+A[j].con())*P(0.5,0),q=(A[i]-A[j].con())*P(0,-0.5);
P r=(B[i]+B[j].con())*P(0.5,0),s=(B[i]-B[j].con())*P(0,-0.5);
dfa[i]=p*r,dfb[i]=p*s,dfc[i]=q*r,dfd[i]=q*s;
}
for(int i=0;i<n;++i)
A[i]=dfa[i]+dfb[i]*P(0,1),B[i]=dfc[i]+dfd[i]*P(0,1);
fft(A,-1); fft(B,-1);
for(int i=0;i<an;++i){
int p=(LL)(A[i].x/n+0.5)%mo,q=(LL)(A[i].y/n+0.5)%mo,r=(LL)(B[i].x/n+0.5)%mo,s=(LL)(B[i].y/n+0.5)%mo;
printf("%d\n",( ((LL)s<<30)+((LL)(q+r)<<15)+p )%mo);
}
}
int main(){
an=read(); int k=read(),i;
ny[0]=ny[1]=c[0]=1,c[1]=k;
for(i=2;i<an;++i) ny[i]=1ll*(mo-mo/i)*ny[mo%i]%mo,c[i]=1ll*c[i-1]*(k+i-1)%mo;
for(i=2;i<an;++i) ny[i]=1ll*ny[i]*ny[i-1]%mo,c[i]=1ll*c[i]*ny[i]%mo;
for(i=0;i<an;++i) a[i]=read();
n=1,ln=0; while(n<=an+an) n<<=1,++ln;
for(i=0;i<n;++i) id[i]=id[i>>1]>>1 | ((i&1)<<ln-1);
mul(a,c);
return 0;
}
bzoj 3625 多項式開根
#include<bits/stdc++.h>
using namespace std;
const int N=262145,mo=998244353,n2=499122177;
char O[1<<14],*S=O,*T=O;
#define gc (S==T&&(T=(S=O)+fread(O,1,1<<14,stdin),S==T)?-1:*S++)
inline int read(){
int x=0,f=1; char ch=gc;
while(ch<'0' || ch>'9'){if(ch=='-')f=-1; ch=gc;}
while(ch>='0' && ch<='9'){x=(x<<1)+(x<<3)+(ch^48); ch=gc;}
return x*f;
}
int id[N],c[N],d[N],a[N],b[N];
int pw(int x,int k){
int r=1;
for(;k;k>>=1,x=1ll*x*x%mo) if(k&1) r=1ll*r*x%mo;
return r;
}
void ntt(int *s,int n,int si){
for(int i=1;i<n;++i) if(i<id[i]) swap(s[i],s[id[i]]);
for(int i=1;i<n;i<<=1){
int wn=pw(3,si==1?(mo-1)/2/i:mo-1-(mo-1)/2/i);
for(int j=0;j<n;j+=i<<1){
int e=1,*b=s+j,*c=b+i;
for(int k=0;k<i;++k,e=1ll*e*wn%mo){
int x=b[k],y=1ll*c[k]*e%mo;
b[k]=(x+y)%mo,c[k]=(x-y)%mo;
}
}
}
if(si<0) for(int ny=pw(n,mo-2),i=0;i<n;++i) s[i]=1ll*s[i]*ny%mo;
}
void inv(int *a,int *b,int n,int ln){
if(n==1) return void(b[0]=pw(a[0],mo-2));
inv(a,b,n>>1,ln-1);
for(int i=0;i<n;++i) c[i]=a[i],c[i+n]=0;
for(int i=0;i< n<<1;++i) id[i]=id[i>>1]>>1|((i&1)<<ln);
ntt(c,n<<1,1); ntt(b,n<<1,1);
for(int i=0;i< n<<1;++i) b[i]=1ll*b[i]*(2-1ll*c[i]*b[i]%mo)%mo;
ntt(b,n<<1,-1);
memset(b+n,0,n*sizeof(int));
}
void Sqrt(int *a,int *b,int n,int ln){
if(n==1) return void(b[0]=1);
Sqrt(a,b,n>>1,ln-1);
memset(d,0,n*2*sizeof(int));
inv(b,d,n,ln);
for(int i=0;i<n;++i) c[i]=a[i],c[i+n]=0;
ntt(c,n<<1,1); ntt(b,n<<1,1); ntt(d,n<<1,1);
for(int i=0;i< n<<1;++i) b[i]=(1ll*c[i]*d[i]%mo+b[i])%mo*n2%mo;
ntt(b,n<<1,-1);
memset(b+n,0,n*sizeof(int));
}
int main(){
int n=read(),m=read(); a[0]=1;
for(int i=1;i<=n;++i){
int x=read(); if(x<=m)a[x]=mo-4;
}
int ln=0; n=1; for(;n<=m;++ln,n<<=1);
Sqrt(a,b,n,ln); b[0]=(b[0]+1)%mo;
memset(a,0,n*sizeof(int)); inv(b,a,n,ln);
for(int i=1;i<=m;++i) printf("%d\n",((a[i]<<1)%mo+mo)%mo);
return 0;
}