1. 程式人生 > >[CF438E] 小朋友和二叉樹

[CF438E] 小朋友和二叉樹

Description

給定一個整數集合 \(c\),對於每個 \(i\in[1,m]\),求有多少種不同的帶點權的二叉樹使得這棵樹點權和為 \(i\) 並且頂點的點權全部在集合 \(c\) 中。\(m\leq 10^5\)

Solution

\(f[i]\) 為點權為 \(i\) 的二叉樹的方案, \(c[i]\) 表示 \(i\) 是否在集合 \(c\) 中。

所以 \(f[i]=\sum\limits_{j=1}^{i} c[j]\cdot\sum\limits_{p=0}^{i-j}f[p]\cdot \sum\limits_{k=0}^{i-j-p}f[k],f[0]=1\)

發現這是個卷積形式,也就是說 \(f[i+j+k]=c[i]\cdot f[j]\cdot f[k]\)

,即 \(f=c\times f\times f\)

解一下方程,\(f=\frac{1\pm \sqrt{1-4c}}{2c}\)

然而 \(c\) 的常數項為 \(0\),所以不能求逆。嘗試分子有理化解得 \(f=\frac2{1\pm\sqrt{1-4c}}\)

\(x=0\) 時,\(c=0,f=1\),所以分母只能取正號。

求逆+開根即可。

Code

#include<bits/stdc++.h>
using std::min;
using std::max;
using std::swap;
using std::vector;
typedef double db;
typedef long long ll;
#define pb(A) push_back(A)
#define pii std::pair<int,int>
#define all(A) A.begin(),A.end()
#define mp(A,B) std::make_pair(A,B)
const int N=4e5+5;
const int mod=998244353;
#define inv(x) ksm(x,mod-2)

int lim,rev[N],f[N];
int n,m,tmpa[N],c[N];
int a[N],b[N],tmpb[N];

int ksm(int a,int b,int ans=1){
    while(b){
        if(b&1) ans=1ll*ans*a%mod;
        a=1ll*a*a%mod;b>>=1;
    } return ans;
}

void ntt(int *f,int opt){
    for(int i=0;i<lim;i++) if(i<rev[i]) swap(f[i],f[rev[i]]);
    for(int mid=1;mid<lim;mid<<=1){
        int tmp=ksm(3,(mod-1)/(mid<<1));
        if(opt<0) tmp=inv(tmp);
        for(int R=mid<<1,j=0;j<lim;j+=R){
            int w=1;
            for(int k=0;k<mid;k++,w=1ll*w*tmp%mod){
                int x=f[j+k],y=1ll*w*f[j+k+mid]%mod;
                f[j+k]=(x+y)%mod,f[j+k+mid]=(mod+x-y)%mod;
            }
        }
    } if(opt<0)
        for(int in=inv(lim),i=0;i<lim;i++) f[i]=1ll*f[i]*in%mod;
}

void get(int n){
    lim=1;while(lim<=n) lim<<=1;
    for(int i=1;i<lim;i++) rev[i]=(rev[i>>1]>>1)|(i&1?lim>>1:0);
}

void solveinv(int len,int *a,int *b){
    if(len==1) return b[0]=inv(a[0]),void();
    solveinv(len>>1,a,b);
    get(len);
    for(int i=0;i<len;i++) tmpa[i]=a[i];
    ntt(tmpa,1),ntt(b,1);
    for(int i=0;i<lim;i++) b[i]=1ll*b[i]*(2ll-1ll*tmpa[i]*b[i]%mod+mod)%mod;
    ntt(b,-1);
    for(int i=len;i<lim;i++) b[i]=0;
    for(int i=0;i<lim;i++) tmpa[i]=0;
}

void solvesqr(int len,int *a,int *b){
    if(len==1) return b[0]=1,void();
    solvesqr(len>>1,a,b);
    solveinv(len,b,tmpb);
    get(len);
    for(int i=0;i<len;i++) tmpa[i]=a[i];
    ntt(tmpb,1),ntt(tmpa,1);
    for(int i=0;i<lim;i++) tmpa[i]=1ll*tmpa[i]*tmpb[i]%mod;
    ntt(tmpa,-1);
    for(int i=0,inv2=mod+1>>1;i<lim;i++) b[i]=1ll*(tmpa[i]+b[i])%mod*inv2%mod;
    for(int i=len;i<lim;i++) b[i]=0;
    for(int i=0;i<lim;i++) tmpa[i]=tmpb[i]=0;
}

int getint(){
    int X=0,w=0;char ch=getchar();
    while(!isdigit(ch))w|=ch=='-',ch=getchar();
    while( isdigit(ch))X=X*10+ch-48,ch=getchar();
    if(w) return -X;return X;
}

signed main(){
    n=getint(),m=getint();
    for(int i=1;i<=n;i++){
        int x=getint();
        a[x]=1;
    }
    for(int i=1;i<=m;i++) if(a[i]) a[i]=mod-4;
    get(m);a[0]=1;solvesqr(lim,a,b);
    get(m);b[0]++;solveinv(lim,b,c);
    for(int i=1;i<=m;i++) printf("%lld\n",2ll*c[i]%mod);
    return 0;
}