bzoj-3625 小朋友和二叉樹
阿新 • • 發佈:2018-12-24
題意:
給出一個大小為n的集合C;
對於i=1...m計算有多少二叉樹滿足每個節點的權值都在集合C中且所有結點權值和為i;
對998244353取模,左右兒子有別;
題解:
生成函式系列題解之三?
這題先對C搞個生成函式吧,令其為C(x);
而我們要求的是樹的計數的函式F(x);
列一下方程,F(x)=C(x)*F^2(x)+1;
F^2(x)表示它的左右兒子的方案,C(x)是限制它自己的權值,+1是因為空樹有一個常數項;
這個方程式很有道理的,不理解就再理解一下;
然後解一下二次方程。。。
解多項式方程?
上求根公式,F(x)=(1±√ 1-4C(x))/2C(x);
二次方程可能有兩個解,但是這個方程只有一個;
因為顯然C(x)無常數項,開根之後出來有一個1,而分母又沒有常數項;
只有取減號時將常數項減掉才能做除法;
多項式開根的具體方法還是倍增;
過程中每一層都要常數次呼叫FFT和多項式求逆;
時間複雜度?T(n)=O(nlogn)+T(n/2)=O(nlogn);
這個複雜度簡直毒瘤。。。至於原因。。。這個複雜度支援各種巢狀。。;
樹套樹都不能無限套而這東西簡直可怕;
程式碼:
#include<math.h> #include<stdio.h> #include<string.h> #include<algorithm> #define N 261244<<1 using namespace std; typedef long long ll; const int mod=998244353; const int div2=499122177; int a[N],b[N],c[N]; int pow(int x,int y) { int ret=1; while(y) { if(y&1) ret=(ll)ret*x%mod; x=(ll)x*x%mod; y>>=1; } return ret; } void NTT(int *a,int len,int type) { int i,j,t,h; for(i=0,t=0;i<len;i++) { if(i>t) swap(a[i],a[t]); for(j=(len>>1);(t^=j)<j;j>>=1); } for(h=2;h<=len;h<<=1) { int wn=pow(5,(mod-1)/h); for(i=0;i<len;i+=h) { int w=1; for(j=0;j<(h>>1);j++,w=(ll)w*wn%mod) { int temp=(ll)w*a[i+j+(h>>1)]%mod; a[i+j+(h>>1)]=(a[i+j]-temp+mod)%mod; a[i+j]=(a[i+j]+temp)%mod; } } } if(type==-1) { for(i=1;i<(len>>1);i++) swap(a[i],a[len-i]); int inv=pow(len,mod-2); for(i=0;i<len;i++) a[i]=(ll)a[i]*inv%mod; } } void inv(int *a,int *b,int len) { if(len==1) { b[0]=pow(a[0],mod-2); return ; } inv(a,b,len>>1); static int temp[N]; memcpy(temp,a,sizeof(int)*len); memset(temp+len,0,sizeof(int)*len); NTT(temp,len<<1,1),NTT(b,len<<1,1); for(int i=0;i<len<<1;i++) b[i]=(ll)b[i]*(2-(ll)temp[i]*b[i]%mod+mod)%mod; NTT(b,len<<1,-1); memset(b+len,0,sizeof(ll)*len); } void sqrt(int *a,int *b,int len) { static int tempa[N],tempb[N]; if(len==1) { b[0]=1; return ; } sqrt(a,b,len>>1); memset(tempb,0,sizeof(int)*len); memset(tempb+len,0,sizeof(int)*len); inv(b,tempb,len); memcpy(tempa,a,sizeof(int)*len); memset(tempa+len,0,sizeof(int)*len); NTT(tempa,len<<1,1),NTT(b,len<<1,1),NTT(tempb,len<<1,1); for(int i=0;i<len<<1;i++) b[i]=(ll)(b[i]+(ll)tempa[i]*tempb[i]%mod)%mod*div2%mod; NTT(b,len<<1,-1); memset(b+len,0,sizeof(int)*len); } int main() { int n,m,i,j,k,len; scanf("%d%d",&n,&m); for(i=1;i<=n;i++) { scanf("%d",&k); if(k<=m) a[k]++; } for(i=1<<30;i;i>>=1) if(m&i) {len=i<<1;break;} for(i=0;i<len;i++) if(a[i]) a[i]=mod-4; a[0]++; sqrt(a,b,len); memcpy(a,b,sizeof(int)*len); a[0]++; memset(b,0,sizeof(int)*len); inv(a,b,len); memcpy(a,b,sizeof(int)*len); for(i=1;i<=m;i++) printf("%d\n",(a[i]+a[i])%mod); return 0; }