[SDOI2017]遺忘的集合(多項式ln+生成函式+莫比烏斯反演)
阿新 • • 發佈:2020-08-03
[SDOI2017]遺忘的集合(多項式ln+生成函式+莫比烏斯反演)
題面
略
分析
設$a_i=[i \in S]$,那麼元素$i$的生成函式為$(\frac{1}{1-xi})$,答案的生成函式為$f(x)=\prod_{i \geq 1}(\frac{1}{1-xi})$. 現在題目已經給出了$f(x)$的各項係數,求$a_i$
為了把乘法化成加法,兩邊取對數,得到: \(-\ln F(x)=\sum_{i \geq 1}a_i\ln(1-x^i)\)
根據$\ln$的泰勒展開$\ln(1-x)=-\sum_{j \geq 1}\frac{x^j}$
\(\ln F(x)=\sum_{i\geq 1}a_i \sum_{j \geq 1}\frac{x^{ij}}{j}\)
交換求和順序,令$ij=k$,則$j=\frac$
\(\ln F(x)=\sum_{k \geq 1}x^k \sum_{i|k} a_i \frac{i}{k}\)
設$\ln F(x)$的第$i$項係數為$g_i$,則$ng_n=\sum_{i|n} ia_i$
看到約數求和,想到莫比烏斯反演:\(na_n=\sum_{i|n} ig_i \mu(\frac{n}{i})\). 那麼我們就可以求出$na_n$,又因為$a_n$只能為0或1,當反演結果不為0的時候輸出即可。顯然這個解字典序最小。
程式碼
#include<iostream> #include<cstdio> #include<cstring> #include<cmath> #include<vector> #define maxn (1<<19) using namespace std; const double pi=acos(-1.0); typedef long long ll; ll mod; inline ll fast_pow(ll x,ll k){ ll ans=1; while(k){ if(k&1) ans=ans*x%mod; x=x*x%mod; k>>=1; } return ans; } inline ll inv(ll x){ return fast_pow(x,mod-2); } struct com{ double real; double imag; com(){ } com(double _real,double _imag){ real=_real; imag=_imag; } friend com operator + (com p,com q){ return com(p.real+q.real,p.imag+q.imag); } friend com operator - (com p,com q){ return com(p.real-q.real,p.imag-q.imag); } friend com operator * (com p,com q){ return com(p.real*q.real-p.imag*q.imag,p.real*q.imag+p.imag*q.real); } friend com operator * (com p,double k){ return com(p.real*k,p.imag*k); } friend com operator / (com p,double k){ return com(p.real/k,p.imag/k); } inline com conj(){ return com(real,-imag); } }; int rev[maxn+5]; com w[maxn+5]; void fft(com *x,int n){ for(int i=0;i<n;i++) if(i<rev[i]) swap(x[i],x[rev[i]]); for(int len=1;len<n;len*=2){ int sz=len*2; for(int l=0;l<n;l+=sz){ int r=l+len-1; for(int i=l;i<=r;i++){ com tmp=x[i+len]; x[i+len]=x[i]-tmp*w[n/sz*(i-l)]; x[i]=x[i]+tmp*w[n/sz*(i-l)]; } } } } void mul(ll *a,ll *b,ll *c,int n,int m){ static com p[maxn+5],q[maxn+5],r[maxn+5],s[maxn+5]; int N=1,L=0; while(N<n+m-1){ N*=2; L++; } for(int i=0;i<n;i++){ ll ta=(a[i]+mod)%mod; p[i]=com(ta>>15,ta&((1<<15)-1)); } for(int i=n;i<N;i++) p[i]=com(0,0); for(int i=0;i<m;i++){ ll tb=(b[i]+mod)%mod; q[i]=com(tb>>15,tb&((1<<15)-1)); } for(int i=m;i<N;i++) q[i]=com(0,0); for(int i=0;i<N;i++) w[i]=com(cos(2*pi*i/N),sin(2*pi*i/N)); for(int i=0;i<N;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(L-1)); fft(p,N); fft(q,N); for(int i=0;i<N;i++){ int j=(i==0?0:N-i); com da=(p[i]+p[j].conj())*com(0.5,0); com db=(p[i]-p[j].conj())*com(0,-0.5); com dc=(q[i]+q[j].conj())*com(0.5,0); com dd=(q[i]-q[j].conj())*com(0,-0.5); r[j]=da*dc+da*dd*com(0,1); s[j]=db*dc+db*dd*com(0,1); } fft(r,N); fft(s,N); for(int i=0;i<n+m-1;i++){ ll ac=(ll)(r[i].real/N+0.5)%mod; ll ad=(ll)(r[i].imag/N+0.5)%mod; ll bc=(ll)(s[i].real/N+0.5)%mod; ll bd=(ll)(s[i].imag/N+0.5)%mod; c[i]=(((ac<<30)+((ad+bc)<<15)+bd)%mod+mod)%mod; } } void poly_inv(ll *f,ll *g,int n){ static ll tmp[maxn+5]; if(n==1){ g[0]=inv(f[0]); return; } poly_inv(f,g,(n+1)/2); mul(f,g,tmp,n,n); mul(tmp,g,tmp,n,n); for(int i=0;i<n;i++) g[i]=(2*g[i]-tmp[i]+mod)%mod;//tmp[i]=f[i]*g[i]^2 } void poly_deriv(ll *f,ll *g,int n){ for(int i=1;i<n;i++) g[i-1]=f[i]*i%mod; g[n-1]=0; } void poly_inter(ll *f,ll *g,int n){ for(int i=n-1;i>=1;i--) g[i]=f[i-1]*inv(i)%mod; g[0]=0; } void poly_ln(ll *f,ll *g,int n){ static ll inv_ln[maxn+5]; poly_deriv(f,g,n); poly_inv(f,inv_ln,n); mul(g,inv_ln,g,n,n); poly_inter(g,g,n*2); } int cnt; int mu[maxn+5],prime[maxn+5]; bool vis[maxn+5]; void sieve_mu(int n){ mu[1]=1; for(int i=2;i<=n;i++){ if(!vis[i]){ prime[++cnt]=i; mu[i]=-1; } for(int j=1;j<=cnt&&i*prime[j]<=n;j++){ vis[i*prime[j]]=1; if(i%prime[j]==0){ mu[i*prime[j]]=0; break; }else mu[i*prime[j]]=-mu[i]; } } } int n; ll f[maxn+5],lnf[maxn+5]; ll a[maxn+5];//實際上存的是a[i]*i vector<int>ans; int main(){ scanf("%d",&n); scanf("%lld",&mod); sieve_mu(n); f[0]=1; for(int i=1;i<=n;i++) scanf("%lld",&f[i]); poly_ln(f,lnf,n+1);//對f的生成函式求ln for(int i=0;i<=n;i++) lnf[i]=lnf[i]*i%mod; for(int i=1;i<=n;i++){ for(int j=1;j*i<=n;j++) a[i*j]+=lnf[i]*mu[j];//莫比烏斯反演 } for(int i=1;i<=n;i++) if(a[i]) ans.push_back(i); printf("%d\n",(int)ans.size()); for(int i=0;i<(int)ans.size();i++) printf("%d ",ans[i]); }