1. 程式人生 > >bzoj-3625 小朋友和二叉樹

bzoj-3625 小朋友和二叉樹

題意:

給出一個大小為n的集合C;

對於i=1...m計算有多少二叉樹滿足每個節點的權值都在集合C中且所有結點權值和為i;

對998244353取模,左右兒子有別;

題解:

生成函式系列題解之三?

這題先對C搞個生成函式吧,令其為C(x);

而我們要求的是樹的計數的函式F(x);

列一下方程,F(x)=C(x)*F^2(x)+1;

F^2(x)表示它的左右兒子的方案,C(x)是限制它自己的權值,+1是因為空樹有一個常數項;

這個方程式很有道理的,不理解就再理解一下;

然後解一下二次方程。。。

解多項式方程?

上求根公式,F(x)=(1±√ 1-4C(x))/2C(x);

二次方程可能有兩個解,但是這個方程只有一個;

因為顯然C(x)無常數項,開根之後出來有一個1,而分母又沒有常數項;

只有取減號時將常數項減掉才能做除法;

多項式開根的具體方法還是倍增;

過程中每一層都要常數次呼叫FFT和多項式求逆;

時間複雜度?T(n)=O(nlogn)+T(n/2)=O(nlogn);

這個複雜度簡直毒瘤。。。至於原因。。。這個複雜度支援各種巢狀。。;

樹套樹都不能無限套而這東西簡直可怕;

程式碼:

#include<math.h>
#include<stdio.h>
#include<string.h>
#include<algorithm>
#define N 261244<<1
using namespace std;
typedef long long ll;
const int mod=998244353;
const int div2=499122177;
int a[N],b[N],c[N];
int pow(int x,int y)
{
	int ret=1;
	while(y)
	{
		if(y&1)
			ret=(ll)ret*x%mod;
		x=(ll)x*x%mod;
		y>>=1;
	}
	return ret;
}
void NTT(int *a,int len,int type)
{
	int i,j,t,h;
	for(i=0,t=0;i<len;i++)
	{
		if(i>t)	swap(a[i],a[t]);
		for(j=(len>>1);(t^=j)<j;j>>=1);
	}
	for(h=2;h<=len;h<<=1)
	{
		int wn=pow(5,(mod-1)/h);
		for(i=0;i<len;i+=h)
		{
			int w=1;
			for(j=0;j<(h>>1);j++,w=(ll)w*wn%mod)
			{
				int temp=(ll)w*a[i+j+(h>>1)]%mod;
				a[i+j+(h>>1)]=(a[i+j]-temp+mod)%mod;
				a[i+j]=(a[i+j]+temp)%mod;
			}
		}
	}
	if(type==-1)
	{
		for(i=1;i<(len>>1);i++)
			swap(a[i],a[len-i]);
		int inv=pow(len,mod-2);
		for(i=0;i<len;i++)
			a[i]=(ll)a[i]*inv%mod;
	}
}
void inv(int *a,int *b,int len)
{
	if(len==1)
	{
		b[0]=pow(a[0],mod-2);
		return ;
	}
	inv(a,b,len>>1);
	static int temp[N];
	memcpy(temp,a,sizeof(int)*len);
	memset(temp+len,0,sizeof(int)*len);
	NTT(temp,len<<1,1),NTT(b,len<<1,1);
	for(int i=0;i<len<<1;i++)	b[i]=(ll)b[i]*(2-(ll)temp[i]*b[i]%mod+mod)%mod;
	NTT(b,len<<1,-1);
	memset(b+len,0,sizeof(ll)*len);
}
void sqrt(int *a,int *b,int len)
{
	static int tempa[N],tempb[N];
	if(len==1)
	{
		b[0]=1;
		return ;
	}
	sqrt(a,b,len>>1);
	memset(tempb,0,sizeof(int)*len);
	memset(tempb+len,0,sizeof(int)*len);
	inv(b,tempb,len);
	memcpy(tempa,a,sizeof(int)*len);
	memset(tempa+len,0,sizeof(int)*len);
	NTT(tempa,len<<1,1),NTT(b,len<<1,1),NTT(tempb,len<<1,1);
	for(int i=0;i<len<<1;i++)	b[i]=(ll)(b[i]+(ll)tempa[i]*tempb[i]%mod)%mod*div2%mod;
	NTT(b,len<<1,-1);
	memset(b+len,0,sizeof(int)*len);
}
int main()
{
	int n,m,i,j,k,len;
	scanf("%d%d",&n,&m);
	for(i=1;i<=n;i++)
	{
		scanf("%d",&k);
		if(k<=m)
		a[k]++;
	}
	for(i=1<<30;i;i>>=1)
		if(m&i)
			{len=i<<1;break;}
	for(i=0;i<len;i++)
		if(a[i])
			a[i]=mod-4;
	a[0]++;
	sqrt(a,b,len);
	memcpy(a,b,sizeof(int)*len);
	a[0]++;	
	memset(b,0,sizeof(int)*len);
	inv(a,b,len);
	memcpy(a,b,sizeof(int)*len);
	for(i=1;i<=m;i++)
		printf("%d\n",(a[i]+a[i])%mod);
	return 0;
}