1. 程式人生 > >bzoj 4555:[Tjoi2016&Heoi2016]求和 多項式求逆

bzoj 4555:[Tjoi2016&Heoi2016]求和 多項式求逆

       如果沒有2^j和j!,那麼題目就是求Σ(i=0,n)Bi(Bi表示第i個貝爾數),而Bi=Σ(j=0,i)S(i,j)。

       考慮第二類斯特林數的含義為將i個不同的數分成j個集合的方案數,那麼*j!就是講i個不同的數分到j個有序集合的方案數,那麼令Fi=Σ(j=0,i)S(i,j)*j!,則Fi表示將i個不同的數分到任意多個有序集合的方案數。考慮Fi的遞推式,列舉最後一個集合的大小k,那麼這個集合可以有1~i的大小,裡面的數可以有Ci,k中組合,所以Fi=Σ(k=1,i)Fi-kCi,k。注意到我們列舉的是左後一個集合的大小,那麼顯然如果考慮2^j這一項則相當於要乘上2,因此就是Fi=

Σ(k=1,i)Fi-kCi,k*2。

       上式中只需要令fi=Fi/i!,就轉化為卷積的形式;更進一步地,令f(x)=Σ(i=0,∞)Fi/i! x^i,令g(x)=Σ(i=1,∞)2/i! x^i,就有f(x)=f(x)*g(x)+1,故f(x)=1/(1-g(x)),直接上多項式求逆即可。

AC程式碼如下:

#include<iostream>
#include<cstdio>
#include<cstring>
#define ll long long
#define mod 998244353
#define N 270005
using namespace std;
int mx;
int n,m,inv[N],a[N],b[N],c[N],pos[N],na[N],w[2][N];
int ksm(int x,int y){
	int t=1;
	for (; y; y>>=1,x=(ll)x*x%mod) if (y&1) t=(ll)t*x%mod;
	return t;
}
void pwk(int n){
	int i,x=ksm(3,(mod-1)/n);
	w[0][0]=w[1][0]=1;
	for (i=1; i<n; i++) w[0][i]=w[1][n-i]=(ll)w[0][i-1]*x%mod;
	for (i=0; i<n; i++){
		pos[i]=pos[i>>1]>>1;
		if (i&1) pos[i]|=n>>1;
	}
}
void fnt(int *a,int n,int flag){if(n>mx)mx=n;
	int i,j,k,l,x,u,v;
	for (i=0; i<n; i++) na[pos[i]]=a[i];
	memcpy(a,na,sizeof(int)*n);
	for (k=1; k<n; k<<=1)
		for (i=0,x=n/k>>1; i<n; i+=k<<1)
			for (j=i,l=0; j<i+k; j++,l+=x){
				u=a[j]; v=(ll)a[j+k]*w[flag][l]%mod;
				a[j]=(u+v)%mod; a[j+k]=(u-v+mod)%mod;
			}
	if (flag){
		x=ksm(n,mod-2);
		for (i=0; i<n; i++) a[i]=(ll)a[i]*x%mod;
	}
}
void solve_inv(int *a,int *b,int n){
	if (n==1){
		b[0]=ksm(a[0],mod-2); return;
	}
	int i; solve_inv(a,b,n>>1);
	memcpy(c,a,sizeof(int)*n); memset(c+n,0,sizeof(int)*n);
	pwk(n<<1);
	fnt(b,n<<1,0); fnt(c,n<<1,0);
	for (i=0; i<(n<<1); i++) b[i]=(2-(ll)b[i]*c[i]%mod+mod)*b[i]%mod;
	fnt(b,n<<1,1); memset(b+n,0,sizeof(int)*n);
}
int main(){
	int i,n; scanf("%d",&n);
	inv[0]=inv[1]=a[0]=m=1;
	while (m<=n) m<<=1;
	for (i=2; i<=n; i++) inv[i]=mod-(ll)inv[mod%i]*(mod/i)%mod;
	for (i=3; i<=n; i++) inv[i]=(ll)inv[i-1]*inv[i]%mod;
	for (i=1; i<=n; i++) a[i]=((mod-inv[i])<<1)%mod;
	solve_inv(a,b,m);
	int ans=b[n];
	for (i=n; i; i--) ans=((ll)ans*i+b[i-1])%mod;
	printf("%d\n",ans);
	return 0;
}


by lych

2016.5.27