1. 程式人生 > 實用技巧 >兩個多項式的卷積【NTT】

兩個多項式的卷積【NTT】

題意

2020計蒜之道決賽:

給出兩個多項式 \(A(x)\)\(B(x)\)

\[A(x)=a_0+a_1x+a_2x^2+a_3x^3+\cdots +a_nx^n\\ B(x)=b_0+b_1x+b_2x^2+b_3x^3+\cdots +b_nx^n \]

\(C(x)\) 為上述兩個多項式的卷積:

\[C(x)=A(x)B(x)=c_0+c_1x+c_2x^2\cdots+c_{2n}x^{2n} \]

現有 \(m\) 次操作,每次操作可能查詢 \(\sum_{i=l}^{r}{c_i}\) ,也可能修改 \(A(x)\) 中的某個係數。

具體如下:

1 l r:代表查詢 \(\sum_{i=l}^{r}{c_i}\)

2 p q:表示把 \(A(x)\)\(x^p\) 的係數增加 \(q\)

\(1\leq n \leq 5000,-10^5\leq a_i \leq 10^5,-10^5\leq b_i\leq 10^5,0\leq p\leq n,-10^5\leq q \leq 10^5\)

輸出結果對 \(998244353\) 取模。

分析

對於某段區間的查詢,可以轉化為對兩個字首和的查詢。但對於第二種操作,如果每次修改之後直接做 \(NTT\) ,複雜度為 \(O(mn\log n)\)

對於這兩個操作而言,一種是“單次優,但數量太大”,一種是“單次劣,但是不限於操作次數”。我們可以減少 \(NTT\) 的次數,同時控制第一種演算法記錄的量不要太大。

維護一個大小為 \(S\) 的集合,每次來一個修改就把這次的修改資訊記錄到集合中,當集合大小增長到一定的閥值,就做一次 \(NTT\) ,同時把集合清空。而一次查詢的答案,等於上次做完 \(NTT\) 的時候的答案,再加上集合中記錄的修改操作對當前查詢的影響,通過合理控制 \(S\) 的大小,可以做到既可以不讓集合太大,又可以減小 \(NTT\) 的次數。

複雜度為:\(O(mS+\frac{m}{S}n\log n)\),通過均值不等式可知:\(S=\sqrt{n\log n}\) 時,複雜度最優秀。

程式碼

#include <bits/stdc++.h>

using namespace std;
typedef long long ll;
const int mod=998244353;
const int N=5100;
const int g=3;
ll A[N<<2],B[N<<2],C[N<<2],pre[N<<2],a[N<<2];
int rev[N<<2],S[N];
int pos[N],w[N],num;
ll power(ll x,ll y)
{
    ll res=1;
    x%=mod;
    while(y)
    {
        if(y&1)
            res=res*x%mod;
        x=x*x%mod;
        y>>=1;
    }
    return res;
}
void NTT(ll *pn,int len,int f)
{
    for(int i=0;i<len;i++)
        if(i<rev[i]) swap(pn[i],pn[rev[i]]);
    for(int i=1;i<len;i<<=1)
    {
        ll wn=power(g,1LL*(mod-1)/(2LL*i));
        if(f==-1) wn=power(wn,mod-2);
        for(int j=0,d=(i<<1);j<len;j+=d)
        {
            ll w=1;
            for(int k=0;k<i;k++)
            {
                ll u=pn[j+k],v=w*pn[j+k+i]%mod;
                pn[j+k]=(u+v)%mod,pn[j+k+i]=((u-v)%mod+mod)%mod;
                w=wn*w%mod;
            }
        }
    }
    if(f==-1)
    {
        ll inv=power(1LL*len,mod-2);
        for(int i=0;i<len;i++)
            pn[i]=pn[i]*inv%mod;
    }
}
void dontt(int len)
{
    for(int i=0;i<len;i++)
        A[i]=a[i];//因為NTT會改變原有係數的位置和值
    NTT(A,len,1);
    for(int i=0;i<len;i++)
        C[i]=A[i]*B[i]%mod;
    NTT(C,len,-1);
    for(int i=1;i<len;i++)
        C[i]=(C[i-1]+C[i])%mod;
}
ll cal(int x)
{
    if(x<0) return 0;
    ll res=C[x];
    for(int i=1;i<=num;i++)
    {
        int t=x-pos[i];
        if(t>=0)//確定B需要的字首和最大下標
            res=(res+(pre[t]*w[i]%mod)+mod)%mod;
    }
    return res;
}
int main()
{
    int n,m,op,x,y;
    scanf("%d",&n);
    for(int i=0;i<=n;i++)
    {
        scanf("%lld",&a[i]);
        a[i]=(a[i]+mod)%mod;
    }
    for(int i=0;i<=n;i++)
    {
        scanf("%lld",&B[i]);
        B[i]=(B[i]+mod)%mod;
        if(i==0)
            pre[i]=B[i];
        else
            pre[i]=(pre[i-1]+B[i])%mod;
    }
    int cnt=0,len=1;
    while(len<=2*n)//注意len不要開小了
    {
        len<<=1;
        cnt++;
    }
    for(int i=n+1;i<len;i++)
        pre[i]=(pre[i-1]+B[i])%mod;
    for(int i=0;i<len;i++)
        rev[i]=(rev[i>>1]>>1)|((i&1)<<(cnt-1));
    NTT(B,len,1);
    dontt(len);
    int mx=(int)sqrt(1.0*n*log2(1.0*n));
    num=0;
    scanf("%d",&m);
    while(m--)
    {
        scanf("%d%d%d",&op,&x,&y);
        if(op==1)
            printf("%lld\n",(cal(y)-cal(x-1)+mod)%mod);
        else
        {
            pos[++num]=x;//記錄修改的位置
            a[x]=(a[x]+y+mod)%mod;
            w[num]=y;//
            if(num>=mx)
            {
                dontt(len);
                num=0;
            }
        }
    }
    return 0;
}