兩個多項式的卷積【NTT】
阿新 • • 發佈:2020-11-16
題意
2020計蒜之道決賽:
給出兩個多項式 \(A(x)\) 和 \(B(x)\) :
\[A(x)=a_0+a_1x+a_2x^2+a_3x^3+\cdots +a_nx^n\\ B(x)=b_0+b_1x+b_2x^2+b_3x^3+\cdots +b_nx^n \]\(C(x)\) 為上述兩個多項式的卷積:
\[C(x)=A(x)B(x)=c_0+c_1x+c_2x^2\cdots+c_{2n}x^{2n} \]現有 \(m\) 次操作,每次操作可能查詢 \(\sum_{i=l}^{r}{c_i}\) ,也可能修改 \(A(x)\) 中的某個係數。
具體如下:
1 l r:代表查詢 \(\sum_{i=l}^{r}{c_i}\)
2 p q:表示把 \(A(x)\) 中 \(x^p\) 的係數增加 \(q\)
\(1\leq n \leq 5000,-10^5\leq a_i \leq 10^5,-10^5\leq b_i\leq 10^5,0\leq p\leq n,-10^5\leq q \leq 10^5\)
輸出結果對 \(998244353\) 取模。
分析
對於某段區間的查詢,可以轉化為對兩個字首和的查詢。但對於第二種操作,如果每次修改之後直接做 \(NTT\) ,複雜度為 \(O(mn\log n)\)。
對於這兩個操作而言,一種是“單次優,但數量太大”,一種是“單次劣,但是不限於操作次數”。我們可以減少 \(NTT\) 的次數,同時控制第一種演算法記錄的量不要太大。
維護一個大小為 \(S\) 的集合,每次來一個修改就把這次的修改資訊記錄到集合中,當集合大小增長到一定的閥值,就做一次 \(NTT\) ,同時把集合清空。而一次查詢的答案,等於上次做完 \(NTT\) 的時候的答案,再加上集合中記錄的修改操作對當前查詢的影響,通過合理控制 \(S\) 的大小,可以做到既可以不讓集合太大,又可以減小 \(NTT\) 的次數。
複雜度為:\(O(mS+\frac{m}{S}n\log n)\),通過均值不等式可知:\(S=\sqrt{n\log n}\) 時,複雜度最優秀。
程式碼
#include <bits/stdc++.h> using namespace std; typedef long long ll; const int mod=998244353; const int N=5100; const int g=3; ll A[N<<2],B[N<<2],C[N<<2],pre[N<<2],a[N<<2]; int rev[N<<2],S[N]; int pos[N],w[N],num; ll power(ll x,ll y) { ll res=1; x%=mod; while(y) { if(y&1) res=res*x%mod; x=x*x%mod; y>>=1; } return res; } void NTT(ll *pn,int len,int f) { for(int i=0;i<len;i++) if(i<rev[i]) swap(pn[i],pn[rev[i]]); for(int i=1;i<len;i<<=1) { ll wn=power(g,1LL*(mod-1)/(2LL*i)); if(f==-1) wn=power(wn,mod-2); for(int j=0,d=(i<<1);j<len;j+=d) { ll w=1; for(int k=0;k<i;k++) { ll u=pn[j+k],v=w*pn[j+k+i]%mod; pn[j+k]=(u+v)%mod,pn[j+k+i]=((u-v)%mod+mod)%mod; w=wn*w%mod; } } } if(f==-1) { ll inv=power(1LL*len,mod-2); for(int i=0;i<len;i++) pn[i]=pn[i]*inv%mod; } } void dontt(int len) { for(int i=0;i<len;i++) A[i]=a[i];//因為NTT會改變原有係數的位置和值 NTT(A,len,1); for(int i=0;i<len;i++) C[i]=A[i]*B[i]%mod; NTT(C,len,-1); for(int i=1;i<len;i++) C[i]=(C[i-1]+C[i])%mod; } ll cal(int x) { if(x<0) return 0; ll res=C[x]; for(int i=1;i<=num;i++) { int t=x-pos[i]; if(t>=0)//確定B需要的字首和最大下標 res=(res+(pre[t]*w[i]%mod)+mod)%mod; } return res; } int main() { int n,m,op,x,y; scanf("%d",&n); for(int i=0;i<=n;i++) { scanf("%lld",&a[i]); a[i]=(a[i]+mod)%mod; } for(int i=0;i<=n;i++) { scanf("%lld",&B[i]); B[i]=(B[i]+mod)%mod; if(i==0) pre[i]=B[i]; else pre[i]=(pre[i-1]+B[i])%mod; } int cnt=0,len=1; while(len<=2*n)//注意len不要開小了 { len<<=1; cnt++; } for(int i=n+1;i<len;i++) pre[i]=(pre[i-1]+B[i])%mod; for(int i=0;i<len;i++) rev[i]=(rev[i>>1]>>1)|((i&1)<<(cnt-1)); NTT(B,len,1); dontt(len); int mx=(int)sqrt(1.0*n*log2(1.0*n)); num=0; scanf("%d",&m); while(m--) { scanf("%d%d%d",&op,&x,&y); if(op==1) printf("%lld\n",(cal(y)-cal(x-1)+mod)%mod); else { pos[++num]=x;//記錄修改的位置 a[x]=(a[x]+y+mod)%mod; w[num]=y;// if(num>=mx) { dontt(len); num=0; } } } return 0; }