[CF438E] 小朋友和二叉樹
阿新 • • 發佈:2019-01-02
Description
給定一個整數集合 \(c\),對於每個 \(i\in[1,m]\),求有多少種不同的帶點權的二叉樹使得這棵樹點權和為 \(i\) 並且頂點的點權全部在集合 \(c\) 中。\(m\leq 10^5\)。
Solution
設 \(f[i]\) 為點權為 \(i\) 的二叉樹的方案, \(c[i]\) 表示 \(i\) 是否在集合 \(c\) 中。
所以 \(f[i]=\sum\limits_{j=1}^{i} c[j]\cdot\sum\limits_{p=0}^{i-j}f[p]\cdot \sum\limits_{k=0}^{i-j-p}f[k],f[0]=1\)
發現這是個卷積形式,也就是說 \(f[i+j+k]=c[i]\cdot f[j]\cdot f[k]\)
解一下方程,\(f=\frac{1\pm \sqrt{1-4c}}{2c}\)
然而 \(c\) 的常數項為 \(0\),所以不能求逆。嘗試分子有理化解得 \(f=\frac2{1\pm\sqrt{1-4c}}\)
當 \(x=0\) 時,\(c=0,f=1\),所以分母只能取正號。
求逆+開根即可。
Code
#include<bits/stdc++.h> using std::min; using std::max; using std::swap; using std::vector; typedef double db; typedef long long ll; #define pb(A) push_back(A) #define pii std::pair<int,int> #define all(A) A.begin(),A.end() #define mp(A,B) std::make_pair(A,B) const int N=4e5+5; const int mod=998244353; #define inv(x) ksm(x,mod-2) int lim,rev[N],f[N]; int n,m,tmpa[N],c[N]; int a[N],b[N],tmpb[N]; int ksm(int a,int b,int ans=1){ while(b){ if(b&1) ans=1ll*ans*a%mod; a=1ll*a*a%mod;b>>=1; } return ans; } void ntt(int *f,int opt){ for(int i=0;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]); for(int mid=1;mid<lim;mid<<=1){ int tmp=ksm(3,(mod-1)/(mid<<1)); if(opt<0) tmp=inv(tmp); for(int R=mid<<1,j=0;j<lim;j+=R){ int w=1; for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){ int x=f[j+k],y=1ll*w*f[j+k+mid]%mod; f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod; } } } if(opt<0) for(int in=inv(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod; } void get(int n){ lim=1;while(lim<=n) lim<<=1; for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0); } void solveinv(int len,int *a,int *b){ if(len==1) return b[0]=inv(a[0]),void(); solveinv(len>>1,a,b); get(len); for(int i=0;i<len;i++) tmpa[i]=a[i]; ntt(tmpa,1),ntt(b,1); for(int i=0;i<lim;i++) b[i]=1ll*b[i]*(2ll-1ll*tmpa[i]*b[i]%mod+mod)%mod; ntt(b,-1); for(int i=len;i<lim;i++) b[i]=0; for(int i=0;i<lim;i++) tmpa[i]=0; } void solvesqr(int len,int *a,int *b){ if(len==1) return b[0]=1,void(); solvesqr(len>>1,a,b); solveinv(len,b,tmpb); get(len); for(int i=0;i<len;i++) tmpa[i]=a[i]; ntt(tmpb,1),ntt(tmpa,1); for(int i=0;i<lim;i++) tmpa[i]=1ll*tmpa[i]*tmpb[i]%mod; ntt(tmpa,-1); for(int i=0,inv2=mod+1>>1;i<lim;i++) b[i]=1ll*(tmpa[i]+b[i])%mod*inv2%mod; for(int i=len;i<lim;i++) b[i]=0; for(int i=0;i<lim;i++) tmpa[i]=tmpb[i]=0; } int getint(){ int X=0,w=0;char ch=getchar(); while(!isdigit(ch))w|=ch=='-',ch=getchar(); while( isdigit(ch))X=X*10+ch-48,ch=getchar(); if(w) return -X;return X; } signed main(){ n=getint(),m=getint(); for(int i=1;i<=n;i++){ int x=getint(); a[x]=1; } for(int i=1;i<=m;i++) if(a[i]) a[i]=mod-4; get(m);a[0]=1;solvesqr(lim,a,b); get(m);b[0]++;solveinv(lim,b,c); for(int i=1;i<=m;i++) printf("%lld\n",2ll*c[i]%mod); return 0; }