1. 程式人生 > >[UOJ182]a^-1 + b problem

[UOJ182]a^-1 + b problem

truct 取值 void n+1 == ++ ems problem string.h

$\newcommand{\align}[1]{\begin{align*}#1\end{align*}}$做這題需要一個前置知識:多項式的多點求值

多項式的多點求值:給定多項式$f(x)$和$x_{1\cdots n}$,要求出$f(1)\cdots f(n)$

首先,我們可以找到$g_i(x)$使得$f(x)=(x-x_i)g_i(x)+C$(就是把$f(x)$對$x-x_i$取模),當$x=x_i$,我們得到$f(x_i)=C$,即$f(x_i)=\left.f(x)\%(x-x_i)\right|_{x=x_i}$,所以我們要求的是$f(x)\%(x-x_i)$,直接對$n$個$x_i$暴力求是$O(n^2\log_2n)$的,比暴力還慢,但一個很顯然的事實是:如果$g(x)=h(x)r(x)$,那麽$f(x)\%g(x)\%h(x)=f(x)\%h(x)$,所以我們這樣分治求解:如果要求出$f(x)$在$x_{l\cdots r}$的取值,那麽就遞歸計算$\align{f(x)\%\prod\limits_{i=l}^r(x-x_i)}$在$x_{l\cdots mid}$和$x_{mid+1\cdots r}$的取值,因為有取模,所以$f(x)$的次數被降了下來,總時間復雜度$T(n)=2T\left(\dfrac n2\right)+O(n\log_2n)=O(n\log_2^2n)$,註意要用分治FFT預處理出$\align{\prod\limits_{i=l}^r(x-x_i)}$,時間復雜度也是$O(n\log_2^2n)$

然後是這道題,因為是全局操作,所以我們定義$f_i(x)$表示經過$i$次操作後,原來的$x$會變成$f_i(x)$,每次操作要麽是將$f(x)$加上一個常數,要麽是把它取倒數,所以它的形式肯定是$f(x)=\dfrac{ax+b}{cx+d}=p+\dfrac q{x+t}$($c=0$要特殊處理)

所以我們要求的答案是$\align{\sum\limits_{i=1}^nf(x_i)}$,展開得到$\align{pn+q\sum\limits_{i=1}^n\dfrac1{x_i+t}}$,在這個式子中,$x_i$是常數,而$t$隨著修改變化($m$個取值),所以我們把它看成關於$t$的函數$\align{g(t)=\sum\limits_{i=1}^n\dfrac1{x_i+t}}=\dfrac{\sum\limits_{i=1}^n\prod\limits_{j\ne i}(x_j+t)}{\prod\limits_{i=1}^n(x_i+t)}$,分母可以分治FFT算,分子是分母的導數,算出來後直接多點求值就做完了...

註意:凡是涉及分治FFT,需要new內存的,一定要註意不能訪問超限,這時assert就派上用場了>_<

#include<stdio.h>
#include<string.h>
#include<assert.h>
const int mod=998244353,maxn=262144;
typedef long long ll;
int mul(int a,int b){return a*(ll)b%mod;}
int ad(int a,int b){return(a+b)%mod;}
int de(int a,int b){return(a-b)%mod;}
void swap(int&a,int&b){
	int c=a;
	a=b;
	b=c;
}
int max(int a,int b){return a>b?a:b;}
int pow(int a,int b){
	int s=1;
	while(b){
		if(b&1)s=mul(s,a);
		a=mul(a,a);
		b>>=1;
	}
	return s;
}
int rev[maxn],N,iN;
void pre(int n){
	int i,k;
	for(N=1,k=0;N<n;N<<=1)k++;
	for(i=0;i<N;i++)rev[i]=(rev[i>>1]>>1)|((i&1)<<(k-1));
	iN=pow(N,mod-2);
}
void ntt(int*a,int on){
	int i,j,k,t,w,wn;
	for(i=0;i<N;i++){
		if(i<rev[i])swap(a[i],a[rev[i]]);
	}
	for(i=2;i<=N;i<<=1){
		wn=pow(3,(on==1)?(mod-1)/i:(mod-1-(mod-1)/i));
		for(j=0;j<N;j+=i){
			w=1;
			for(k=0;k<i>>1;k++){
				t=mul(w,a[i/2+j+k]);
				a[i/2+j+k]=de(a[j+k],t);
				a[j+k]=ad(a[j+k],t);
				w=mul(w,wn);
			}
		}
	}
	if(on==-1){
		for(i=0;i<N;i++)a[i]=mul(a[i],iN);
	}
}
int t0[maxn];
void getinv(int*a,int*b,int n){
	if(n==1){
		b[0]=pow(a[0],mod-2);
		return;
	}
	int i;
	getinv(a,b,n>>1);
	pre(n<<1);
	memset(t0,0,N<<2);
	memcpy(t0,a,n<<2);
	ntt(t0,1);
	ntt(b,1);
	for(i=0;i<N;i++)b[i]=mul(b[i],2-mul(b[i],t0[i]));
	ntt(b,-1);
	for(i=n;i<N;i++)b[i]=0;
}
int ta[maxn],tb[maxn],tc[maxn];
void add(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	for(int i=0;i<=k;i++)tc[i]=ad(a[i],b[i]);
	while(k!=0&&tc[k]==0)k--;
	memcpy(c,tc,(k+1)<<2);
}
void dec(int*a,int n,int*b,int m,int*c,int&k){
	k=max(n,m);
	for(int i=0;i<=k;i++)tc[i]=de(a[i],b[i]);
	while(k!=0&&tc[k]==0)k--;
	memcpy(c,tc,(k+1)<<2);
}
void reverse(int*a,int n){
	for(int i=0;i<=n>>1;i++)swap(a[i],a[n-i]);
}
void mul(int*a,int n,int*b,int m,int*c,int&k){
	int i;
	k=n+m;
	pre(k+1);
	memset(ta,0,N<<2);
	memset(tb,0,N<<2);
	memcpy(ta,a,(n+1)<<2);
	memcpy(tb,b,(m+1)<<2);
	ntt(ta,1);
	ntt(tb,1);
	for(i=0;i<N;i++)tc[i]=mul(ta[i],tb[i]);
	ntt(tc,-1);
	memcpy(c,tc,(k+1)<<2);
}
int t1[maxn];
void div(int*a,int n,int*b,int m,int*c,int&k){
	if(n<m){
		k=0;
		return;
	}
	int i,rn;
	for(rn=1;rn<n-m+1;rn<<=1);
	memset(ta,0,rn<<3);
	memset(tb,0,rn<<3);
	memcpy(ta,a,(n+1)<<2);
	memcpy(tb,b,(m+1)<<2);
	reverse(tb,m);
	for(i=rn;i<=m;i++)tb[i]=0;
	memset(t1,0,rn<<3);
	getinv(tb,t1,rn);
	pre(rn<<1);
	reverse(ta,n);
	for(i=rn;i<=n;i++)ta[i]=0;
	ntt(ta,1);
	ntt(t1,1);
	for(i=0;i<N;i++)tc[i]=mul(ta[i],t1[i]);
	ntt(tc,-1);
	k=n-m;
	reverse(tc,k);
	while(k!=0&&tc[k]==0)k--;
	memcpy(c,tc,(k+1)<<2);
}
int len;
void modulo(int*a,int n,int*b,int m,int*c,int&k){
	if(n<m){
		k=n;
		memcpy(c,a,(n+1)<<2);
		return;
	}
	div(a,n,b,m,t1,k);
	mul(t1,k,b,m,t1,k);
	//assert(max(n,k)<=len);
	dec(a,n,t1,k,c,k);
}
struct frac{//(ax+b)/(cx+d)
	int a,b,c,d;
	void add(int k){
		a=ad(a,mul(c,k));
		b=ad(b,mul(d,k));
	}
	void inv(){
		swap(a,c);
		swap(b,d);
	}
}fr[60010];
int x[100010],op[60010],v[60010],ti[60010],*tr[240010],M;
void build(int l,int r,int x){
	if(l==r){
		tr[x]=new int[2];
		tr[x][0]=-ti[l];
		tr[x][1]=1;
		return;
	}
	int mid=(l+r)>>1;
	build(l,mid,x<<1);
	build(mid+1,r,x<<1|1);
	tr[x]=new int[r-l+2];
	mul(tr[x<<1],mid-l+1,tr[x<<1|1],r-mid,tr[x],x);
}
void solve(int*f,int n,int l,int r,int x,int*ans){
	int mid=(l+r)>>1,*now;
	now=new int[r-l+1];
	len=r-l;
	modulo(f,n,tr[x],r-l+1,now,n);
	if(l==r){
		ans[l]=now[0];
		return;
	}
	solve(now,n,l,mid,x<<1,ans);
	solve(now,n,mid+1,r,x<<1|1,ans);
}
int t2[maxn],t3[maxn];
int*solve2(int l,int r){
	int mid,*res,*L,*R,len;
	res=new int[r-l+2];
	if(l==r){
		res[1]=1;
		res[0]=x[l];
		return res;
	}
	mid=(l+r)>>1;
	L=solve2(l,mid);
	R=solve2(mid+1,r);
	mul(L,mid-l+1,R,r-mid,res,len);
	return res;
}
int ans1[60010],ans2[60010],ans[60010],up[100010];
int main(){
	int n,m,i,p,q,del,*res;
	scanf("%d%d",&n,&m);
	for(i=1;i<=n;i++)scanf("%d",x+i);
	fr->a=fr->d=1;
	fr->b=fr->c=0;
	for(i=1;i<=m;i++){
		fr[i]=fr[i-1];
		scanf("%d",op+i);
		if(op[i]==1){
			scanf("%d",v+i);
			fr[i].add(v[i]);
		}else
			fr[i].inv();
		if(op[i]==2){
			M++;
			ti[M]=mul(fr[i].d,pow(fr[i].c,mod-2));
		}
	}
	del=0;
	for(i=1;i<=n;i++)ans[0]=ad(ans[0],x[i]);
	if(M==0){
		for(i=1;i<=m;i++){
			del=ad(del,v[i]);
			printf("%d\n",ad(ans[0],mul(del,n)));
		}
		return 0;
	}
	build(1,M,1);
	res=solve2(1,n);
	for(i=1;i<=n;i++)up[i-1]=mul(res[i],i);
	solve(up,n-1,1,M,1,ans1);
	solve(res,n,1,M,1,ans2);
	M=del=0;
	for(i=1;i<=m;i++){
		if(op[i]==1){
			del=ad(del,v[i]);
			printf("%d\n",ad(ad(ans[M],mul(n,del)),mod));
		}else{
			M++;
			if(fr[i].c==0){
				printf("%d\n",ans[M]=ad(mul(ad(mul(fr[i].a,ans[0]),mul(fr[i].b,n)),pow(fr[i].d,mod-2)),mod));
				continue;
			}
			p=mul(fr[i].a,pow(fr[i].c,mod-2));
			q=mul(de(mul(fr[i].b,fr[i].c),mul(fr[i].a,fr[i].d)),pow(mul(fr[i].c,fr[i].c),mod-2));
			ans[M]=ad(mul(p,n),mul(q,mul(ans1[M],pow(ans2[M],mod-2))));
			printf("%d\n",ad(ans[M],mod));
			del=0;
		}
	}
}

[UOJ182]a^-1 + b problem