bzoj3625(NTT+多項式求逆+多項式開根)
阿新 • • 發佈:2018-12-24
這題是我搜NTT搜到的,當時就看到“多項式開根”這樣的標題,於是找到了L-leader的部落格,補了下冪級數的東西,用兩節數學課學會了。
我再看題解,好像都是教我怎麼開方,求逆的,然後又拖了幾天。終於昨晚睡不著,突然就想到了。。。
先介紹一下生成函式。
簡單的說,就是一個數組a[0..n],可以生成一個多項式函式(冪級數)
題意是給你二叉樹每個節點可能的點權集合C,元素都<=1e5,對於所有1<=s<=m,有種不同的二叉樹滿足點權和為s,答案模一個費馬素數。
設g[i]為一個01陣列,表示i是否在C出現,f[i]為權值和為i的方案數,即是答案。F為f的生成函式,G為g的生成函式,根據題意g[0]=0,因為存在空樹,f[0]=1;
我們可以列舉二叉樹根的權值,剩下左右兒子為子問題,就有
就可以大概知道
根據f[0]=1,g[1]=0,所以有
通過解一元二次方程,再結合f[0]=1,g[1]=0
然後就是多項式求逆和多項式開根了。
#include <iostream>
#include <fstream>
#include <algorithm>
#include <cmath>
#include <ctime>
#include <cstdio>
#include <cstdlib>
#include <cstring>
using namespace std;
#define mmst(a, b) memset(a, b, sizeof(a))
#define mmcp(a, b) memcpy(a, b, sizeof(b))
typedef long long LL;
const int p=998244353,I2=499122177;
const int N=800400;
int cheng(int a,int b)
{
int res=1;
for(;b;b>>=1,a=(LL)a*a %p)
if(b&1)
res=(LL)res*a%p;
return res;
}
int n,rev[N];
void init(int lim)
{
n=1;
int k=-1;
while(n<lim)
n<<=1,k++;
for(int i=0;i<n;i++)
rev[i]=(rev[i>>1] >> 1) | ((i&1)<<k);
}
void ntt(int *a,int ops)
{
for(int i=0;i<n;i++)
if(i<rev[i])
swap(a[i],a[rev[i]]);
for(int l=2;l<=n;l<<=1)
{
int m=l>>1,wn;
if(ops)
wn=cheng(3,(p-1)/l);
else
wn=cheng(3,p-1-(p-1)/l);
for(int i=0;i<n;i+=l)
{
int w=1;
for(int k=0;k<m;k++)
{
int t=(LL)a[i+k+m]*w%p;
a[i+k+m]=(a[i+k]-t+p)%p;
a[i+k]=(a[i+k]+t)%p;
w=(LL)w*wn%p;
}
}
}
if(!ops)
{
int Inv=cheng(n,p-2);
for(int i=0;i<n;i++)
a[i]=(LL)a[i]*Inv%p;
}
}
int g[N];
int mx=1,by,nn,mm;
int X[N],Y[N],sqr[N],A[N],B[N],C[N];
void Inverse(int *a,int *b,LL len)
{
if(len==1)
{
b[0]=cheng(a[0],p-2);
return;
}
Inverse(a,b,len>>1);
init(2*len);
for(int i=0;i<len;i++)
X[i]=a[i];
for(int i=0;i<(len>>1);i++)
Y[i]=b[i];
ntt(X,1);
ntt(Y,1);
for(int i=0;i<n;i++)
X[i]=(2ll*Y[i]%p-(LL)X[i]*Y[i]%p*Y[i]%p+p)%p;
ntt(X,0);
for(int i=0;i<n;i++)
{
if(i>=len)
b[i]=0;
else
b[i]=X[i];
X[i]=Y[i]=0;
}
}
void Sqrt(int len)
{
if(len==1)
{
sqr[0]=1;//本題被開方的多項式常數項為1
return;
}
Sqrt(len>>1);
Inverse(sqr,A,len);
for(int i=0;i<(len>>1);i++)
B[i]=sqr[i];
for(int i=0;i<len;i++)
C[i]=g[i];
init(len*2);
ntt(A,1);
ntt(B,1);
ntt(C,1);
for(int i=0;i<n;i++)
A[i]=(1ll*C[i]+(LL)B[i]*B[i])%p*I2%p*A[i]%p;
ntt(A,0);
for(int i=0;i<n;i++)
{
sqr[i]=A[i];
if(i>=len)
sqr[i]=0;
A[i]=B[i]=C[i]=0;
}
}
int main()
{
cin>>nn>>mm;
while(mx<=mm)
mx<<=1;
for(int i=1;i<=nn;i++)
{
scanf("%d",&by);
if(by<=mm)
g[by]=1;
}
for(int i=0;i<mx;i++)
if(g[i])
g[i]=p-4;
g[0]=1;
Sqrt(mx);
sqr[0]=(sqr[0]+1)%p;
mmst(g,0);
Inverse(sqr,g,mx);
for(int i=1;i<=mm;i++)
printf("%d\n",(g[i]+g[i])%p);
return 0;
}