1. 程式人生 > >任意模數NTT(學習筆記)

任意模數NTT(學習筆記)

F F T FFT 有時候會被卡精度?所以可能會有模數,有了模數以後就需要模數的原根。

原根是什麼?(留坑待填)

N T T

NTT 有很多種解決方法
1. 1. 特殊模數
( 2 k +
1 ) ( p 1 ) ( p
1 ) > D F T (2k+1)|(p−1),(p−1)>DFT的長度
,可以直接暴力求原根 g g ,用 g g 代替單位複數根

2. 2. 一般模數

  • 三模數法(9次DFT)
    如果這個模數的原根不好求,而模數又很大,可以用三個模數
    要求 m o d 1 × m o d 2 × m o d 3 n × p 2 mod1\times mod2\times mod3\ge n\times p^2
    常用 998244353 , 1004535809 , 469762049 998244353,1004535809,469762049 ,因為他們的原根都是 3 3
    對這幾個模數分別做 D F T DFT ,然後用中國剩餘定理合併,然後對原模數取模即可
    中國剩餘定理部分
    a n s c 1   m o d   m 1 ans\equiv c1\ mod\ m1
    a n s c 2   m o d   m 2 ans\equiv c2\ mod\ m2
    a n s c 3   m o d   m 3 ans\equiv c3\ mod\ m3
    如果直接合並的話會爆 l o n g   l o n g long\ long ,可以先合併兩個,再用奇技淫巧合並第三個
    合併前兩個:
    a n s ( c 1 × m 2 × i n v ( m 2 , m 1 ) + c 2 × m 1 × i n v ( m 1 , m 2 ) ) m o d   ( m 1 × m 2 ) ans\equiv (c1\times m2\times inv(m2,m1)+c2\times m1\times inv(m1,m2)) mod\ (m1\times m2)
    其中 i n v ( x , y ) inv(x,y) 表示 x x y y 取模的逆元。
    把上式化簡為 a n s C   m o d   M ans\equiv C\ mod\ M
    a n s = x × M + C = y × m 3 + c 3 ans=x\times M+C=y\times m3+c3
    接下來很重要:求出 x   m o d   m 3 x\ mod\ m3 意義下的值:
    x ( c 3 C ) × M 1   m o d   m 3 x\equiv (c3-C)\times M^{-1}\ mod\ m3
    這樣就可以算出右半部分的值 q q ,令 x = k × m 3 + q x=k\times m3+q ,代入 a n s ans 得:
    a n s = k × m 1 × m 2 × m 3 + q × M + C ans=k\times m1\times m2\times m3+q\times M+C
    因為 a n s [ 0 , m 1 × m 2 × m 3 ) ans\in [0,m1\times m2\times m3) ,所以 k = 0 k=0 ,於是 a n s ans 就可直接計算。
    做9次DFT,常數極大
    有一道模板題luogu4245
    直接用三模數法,程式碼如下:
#include<iostream>
#include<cstdio>
#include<algorithm>
#include<cstring>
#include<cmath>
#define LL long long
#define maxn 400005
using namespace std;
const LL mod1=998244353,mod2=1004535809,mod3=469762049,g=3;
const LL M=1LL*mod1*mod2;

inline int rd(){
    int x=0,f=1;char c=' ';
    while(c<'0' || c>'9') f=c=='-'?-1:1,c=getchar();
    while(c<='9' && c>='0') x=x*10+c-'0',c=getchar();
    return x*f; 
}

int n,m,p,rev[maxn],limit=1,l;
LL a[3][maxn],b[3][maxn],ans[maxn];

inline LL mul(LL x,int k,LL MOD){
    LL ret=0;
    while(k){
        if(k&1) (ret+=x)%=MOD;
        (x+=x)%=MOD; k>>=1;
    } return ret%MOD;
}

inline LL qpow(LL x,int k,int MOD){
    LL ret=1;
    while(k){
        if(k&1) ret=ret*x%MOD;
        x=x*x%MOD; k>>=1;
    } return ret%MOD;
}

inline void NTT(LL *F,int type,int MOD){
    for(int i=0;i<limit;i++){
        F[i]%=MOD;
        if(i<rev[i]) swap(F[i],F[rev[i]]);
    }
    for(int mid=1;mid<limit;mid<<=1){
        LL Wn=qpow(g,type==1?(MOD-1)/(mid<<1):(MOD-1-(MOD-1)/(mid<<1)),MOD);//!
        for(int r=mid<<1,j=0;j<limit;j+=r){
            LL w=1;
            for(int k=0;k<mid;k++,w=w*Wn%MOD){
                LL x=F[j+k],y=w*F[j+mid+k]%MOD;
                F[j+k]=(x+y)%MOD; F[j+mid+k]=(x-y+MOD)%MOD;
            }
        }
    }
    if(type==-1){
        LL INV=qpow(limit,MOD-2,MOD);//除以limit
        for(int i=0;i<limit;i++) F[i]=F[i]*INV%MOD; 
    }
}

inline void CRT(){
    for(int i=0;i<limit;i++){
        LL tmp=0;
        (tmp+=mul