拉格朗日插值 & FFT & NTT 及其應用
進軍多項式。
1. 拉格朗日插值
1.1. 普通插值
首先給出公式:
\[\large F(x)=\sum_{k=1}^n(y_k\prod_{i=1,i\neq k}^n \dfrac{x-x_i}{x_k-x_i}) \]對於每對點值 \((x_k,y_k)\),我們需要構造出一個函式 \(G(x)\),使得其在 \(x=x_k\) 處的取值為 \(y_k\),其餘處取值為 \(0\)。
首先建構函式 \(D(x)=\prod_{i=1,i\neq k}^n x-x_i\)。顯然當 \(i\neq k\) 時,有 \(D(x_i)=0\)。但是現在我們不能保證 \(D(x_k)=y_k\)。為了使 \(D(x_k)=y_k\)
通常情況下,題目會要求我們求出 \(F(x)\) 在給定某個 \(x\) 處的取值,此時我們不把 \(x\) 看做函式的一個元,而是直接將 \(x\) 帶入上式即可。時間複雜度為 \(\mathcal{O}(n^2)\),程式碼見例題 I。
1.2. 連續取值插值
很多情況下,我們求出的點值 \(x_i\) 滿足 \(x_i=i\),即 \(x_i\) 是連續的。此時我們重新寫一下公式:
\[\sum_{k=1}^n(y_k\prod_{i=1,i\neq k}^n \dfrac{x-i}{k-i}) \]記 \(p_i=\prod_{j=1}^ix-i\)
預處理階乘就可以線性插值,應用見例題 II。
1.3. 求 \(F(x)\) 各項係數
建構函式 \(D(x)=\prod_{i=1}^n(x-x_i)\),設 \(d_k=y_i\prod_{i=1,i\neq k}^n\dfrac{1}{x_k-x_i}\),則 \(F(x)=\sum_{k=1}^n \dfrac{d_kD(x)}{x-x_k}\)。注意到 \(D(x)\)
1.4. 例題
I. P4781 【模板】拉格朗日插值
板子題。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod=998244353;
const int N=2e3+5;
ll ksm(ll a,ll b){
ll s=1;
while(b){
if(b&1)s=s*a%mod;
a=a*a%mod,b>>=1;
} return s;
}
int n,k,x[N],y[N],ans;
int main(){
cin>>n>>k;
for(int i=1;i<=n;i++)cin>>x[i]>>y[i];
for(ll i=1,s1=1,s2=1;i<=n;i++,s1=s2=1){
for(int j=1;j<=n;j++)if(i!=j)s1=s1*(k-x[j])%mod,s2=s2*(x[i]-x[j])%mod;
ans=(ans+y[i]*s1%mod*ksm(s2,mod-2))%mod;
} cout<<(ans%mod+mod)%mod<<endl;
return 0;
}
II. CF622F The Sum of the k-th Powers
經典題。一個結論是自然數 \(k\) 次方和是 \(k+1\) 次多項式,那麼只需要帶 \((i,i^k)\ (i\in [0,k+1])\) 插值即可。注意到當 \(i=0\) 時對答案無貢獻,所以在插值的時候可以跳過(但是預處理 \(p,s\) 的時候仍應考慮 \(i=0\) 時 \(n-i\) 的這個 \(n\))。
時間複雜度 \(\mathcal{O}(k\log k)\),使用線性篩篩 \(i^k\) 可以做到線性。
#include <bits/stdc++.h>
using namespace std;
#define ll long long
const ll mod=1e9+7;
const int N=1e6+5;
ll ksm(ll a,ll b){
ll s=1;
while(b){
if(b&1)s=s*a%mod;
a=a*a%mod,b>>=1;
} return s;
} ll inv(ll x){return ksm(x,mod-2);}
ll n,k,ans;
ll p[N],s[N],fc[N];
int main(){
cin>>n>>k,s[k+2]=1;
for(int i=0;i<=k+1;i++)fc[i]=i?fc[i-1]*i%mod:1,p[i]=i?p[i-1]*(n-i)%mod:n;
for(int i=k+1;~i;i--)s[i]=i==k+1?n-i:s[i+1]*(n-i)%mod;
for(ll i=1,res=0;i<=k+1;i++){
res=(res+ksm(i,k))%mod;
ans=(ans+p[i-1]*s[i+1]%mod*res%mod*ksm(fc[i]*((k-i)&1?1:-1)*fc[k+1-i]%mod,mod-2))%mod;
} cout<<(ans%mod+mod)%mod<<endl;
return 0;
}
2. FFT
對於一個 \(n\) 次多項式 \(F(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n\),我們可以用 \(n+1\) 個點值 \((x_k,y_k)\) 唯一確定該多項式。即 \(y_k=\sum_{i=0}^na_ix_k^i\)(注意下文的 \(n\) 表示不小於多項式次數 \(+1\) 的最小的 \(2\) 的冪)。設 \(A=F\times G\),注意到 \(A(x)=F(x)G(x)\),其中 \(x\) 是一個確定的值。因此,我們只需要將一個多項式快速(\(n\log n\))轉點值,再快速轉成係數表示,就可以做到時間複雜度 \(n\log n\) 的多項式乘法。
點值的取法很有講究,高明的方法能夠極大化地減小時間複雜度。這裡我們採用 \(n\) 次單位根 \(\omega_n\),並主要利用以下性質:
- \(\omega_n^k=\omega_{2n}^{2k}\)。
- \(\omega_n^k=-\omega_n^{k+n/2}\)。
- \(\omega_n^n=1\)。
- \(\omega_n=\cos(\dfrac{2\pi}{n})+i\sin(\dfrac{2\pi}{n})\),從而計算 \(n\) 次單位根。
當 \(n=4\) 時,\(F(x)=a_0+a_1x+a_2x^2+a_3x^3=(a_0+a_2x^2)+x(a_1+a_3x^2)\)。設 \(L(x)=a_0+a_2x\),\(R(x)=a_1+a_3x\),那麼 \(F(x)=L(x^2)+xR(x^2)\)。將單位根 \(\omega_n\) 帶入,那麼當 \(k<\dfrac{n}{2}\) 時,\(F(\omega_n^k)=L(\omega_{n/2}^k)+\omega_{n}^kR(\omega_{n/2}^k)\),\(F(\omega_n^{k+n/2})=L(\omega_{n/2}^k)-\omega_n^kR(\omega_{n/2}^k)\)。注意到兩式只有一個正負號不同。因此,如果我們已經知道了 \(L,R\) 在 \(\omega_{n/2}^i,\ i\in[0,\dfrac{n}{2})\) 處的取值,那麼我們就可以線上性時間內求出 \(F\) 在 \(\omega_n^i, \ i\in[0,n)\) 處的取值。考慮遞迴樹的每一層都是線性的,因此總複雜度為 \(n\log n\)。
遞迴處理很慢,於是我們使用迭代:考慮係數 \(a_i\) 向 \(L\) 分治為 \(0\),向 \(R\) 分治為 \(1\),那麼顯然 \(a_i\) 最終形成的 “分治序列” 形成的二進位制數倒過來就是 \(i\) 的二進位制表示。考慮求出 \(r_i\) 表示 \(i\) 二進位制翻轉後得到的數。假設 \(r_0,r_1,\cdots,r_{i-1}\) 都已經求出 ,那麼有 \(r_i=\lfloor\dfrac{r_{\lfloor\dfrac{i}{2}\rfloor}}{2}\rfloor+\dfrac{n}{2}\times(i\bmod 2)\)。左邊是 \(i\) 不考慮最低位(即假設其為 \(0\))時二進位制翻轉得到的數,然後再考慮 \(i\) 的最低位的影響即可。然後對於每一對無序對 \((i,r_i)\),將 \(a_i\) 與 \(a_{r_i}\) 交換,那麼最終的分治樹的形態類似線段樹,直接從最底層向上迭代即可。該操作被稱為蝴蝶迭代。
卡常技巧:
- 不要寫建構函式。
- 過載加減乘運算子放到
struct
裡面。
3. NTT
由於複數運算很慢,而通常情況下我們是在模意義下進行多項式運算,所以當模數取一些特殊值時,我們可以用 \(\mathbb{Z}\) 中的數 \(g\) 代替單位根的複數運算。具體地,\(g\) 需要是模 \(p\) 的原根。
鴿著。
4. FFT 優化字串匹配
Trick 1:翻轉一個序列常常可以使關於它的某些計算變成卷積的形式。
對於一個文字串 \(s\) 與匹配串 \(t\)(下標從 \(0\) 開始),設它們的長度分別為 \(n\),\(m\)。稱它們在位置 \(p\) 匹配,當且僅當對於任意 \(i\in[0,m)\),有 \(s_{p-m+i+1}=t_i\)。不難發現它的充分條件為 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2=0\)。展開,得到 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t^2_{i}-2s_{p-m+i+1}t_i)=0\)。注意到前面兩項容易預處理得到,但後面一項同時關於兩個字串,比較麻煩。又發現它的形式類似卷積,但又不是卷積:翻轉字串 \(t\) 即可。因此柿子變為 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t_{m-i-1}^2-2s_{p-m+i+1}t_{m-i-1})\),後面一項即 \(2\sum_{0\leq i<m,i+j=p}s_jt_i\),FFT 計算即可。
對於有萬用字元的字串,我們不妨將該位置上的值設為 \(0\),然後乘到上面的柿子裡去,即 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2s_{p-m+i+1}t_i\)。化簡得到 \(\sum_{0\leq i<m,i+j=p}(s^3_jt_i-2s^2_jt^2_i+s_jt^3_i)\),做 6 次 DFT + 1 次 IDFT 即可。
Trick 2:一般情況下,上述方法足夠應付多數題目。但是如果萬惡的出題人卡了 FFT 精度就涼涼了。因此,為了保險,應儘量使用 NTT。但是 NTT 也有一個致命問題:如果計算出來的值剛好是 \(998244353\) 的倍數,那麼就會在不匹配的地方判定匹配。看起來好像沒有解決辦法了?非也。在 P4173 殘缺的字串 這題的討論區 https://www.luogu.com.cn/discuss/show/303076,我找到了一個高明的手段解決這個問題:注意到上述柿子最大值可達到 \(3\times m\times (|\mathbb{\Sigma}|-1)^4\),如果 \(m\) 取 \(5\times 10^5\),字符集取大小 \(26\),那麼數量級為 \(1.5\times 10^6\times 25^4\approx 6\times 10^{11}\),顯然不太行。但是實際上我們沒必要將整個 \(s\) 和 \(t\) 乘進去,因為我們關心的只是某一位是否是萬用字元,而具體這一位是什麼並不重要。因此,我們令 \(S_i=[s_i\neq\texttt{*}]\),\(T_i=[t_i\neq\texttt{*}]\),只需要將 \(S,T\) 而非 \(s,t\) 乘入即可。 這時最大值僅為 \(3\times 5\times 10^5\times 25^2=9.375\times 10^8<998244353\),有驚無險地保證了正確性。當然,不同題目還應根據 \(m\) 和字符集大小的不同取值具體分析正確性。
也許看了上述分析的你認為:既然將 \(S,T\) 乘進去,那麼不如直接將是萬用字元的位置當做 \(0\) 來計算 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2\) 不就好了?非也。因為這樣在計算只關於某一個字串的項時,考慮不到另外一個字串的對應位置是否是萬用字元。因此,6 次 DFT 是逃不掉的。最終的柿子即為 \(\sum_{0\leq i<m,i+j=p}((s_j^2S_j)\times T_i-2(s_jS_j)\times (t_iT_i)+S_j\times (t^2_iT_i))\)。
例題 & 程式碼可以看 II.
5. FFT & NTT 例題
I. P3803 【模板】多項式乘法(FFT)
FFT:
#include <bits/stdc++.h>
using namespace std;
#define ld double
const int N=1<<21;
const ld Pi=acos(-1);
struct com{
ld x,y;
com operator + (com b){return (com){x+b.x,y+b.y};}
com operator - (com b){return (com){x-b.x,y-b.y};}
com operator * (com b){return (com){x*b.x-y*b.y,x*b.y+y*b.x};}
}a[N],b[N];
int n,m,lim=1,bit,r[N];
void FFT(com *a,int tp){
for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
for(int l=1;l<lim;l<<=1){
com wn={cos(Pi/l),tp*sin(Pi/l)};
for(int j=0;j<lim;j+=(l<<1)){
com w={1,0},x,y;
for(int k=0;k<l;k++,w=w*wn)
x=a[j+k],y=w*a[j+k+l],a[j+k]=x+y,a[j+k+l]=x-y;
}
}
}
int main(){
cin>>n>>m;
for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
while(lim<=n+m)lim<<=1,bit++;
for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<bit-1);
FFT(a,1),FFT(b,1);
for(int i=0;i<lim;i++)a[i]=a[i]*b[i];
FFT(a,-1);
for(int i=0;i<=n+m;i++)cout<<(int)(a[i].x/lim+0.5)<<" ";
return 0;
}
NTT:
#include <bits/stdc++.h>
using namespace std;
#define ll long long
#define ull unsigned long long
#define gc getchar()
inline int read(){
int x=0; char s=gc;
while(!isdigit(s))s=gc;
while(isdigit(s))x=s-'0',s=gc;
return x;
}
const ll mod=998244353;
const int N=1<<21;
ll ksm(ll a,ll b){
ll s=1;
while(b){
if(b&1)s=s*a%mod;
a=a*a%mod,b>>=1;
} return s;
} ll inv(ll x){return ksm(x,mod-2);}
const int G=3;
const int ivG=inv(3);
ll tr[N],lim=1,l;
ll n,m,f[N],g[N];
void NTT(ll *a,bool tp){
static ull f[N],w[N]; w[0]=1;
for(int i=0;i<lim;i++)f[i]=a[tr[i]];
for(int l=1;l<lim;l<<=1){
ll wn=ksm(tp?G:ivG,(mod-1)/(l+l));
for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
for(int i=0;i<lim;i+=l<<1){
for(int j=0;j<l;j++){
int y=w[j]*f[i|j|l]%mod;
f[i|j|l]=f[i|j]+mod-y,f[i|j]+=y;
}
} if(l==(1<<17))for(int i=0;i<lim;i++)f[i]%=mod;
}
if(!tp){
ll iv=inv(lim);
for(int i=0;i<lim;i++)a[i]=f[i]%mod*iv%mod;
} else for(int i=0;i<lim;i++)a[i]=f[i]%mod;
}
int main(){
cin>>n>>m;
for(int i=0;i<=n;i++)f[i]=read();
for(int i=0;i<=m;i++)g[i]=read();
while(lim<=n+m)lim<<=1,l++;
for(int i=1;i<lim;i++)tr[i]=(tr[i>>1]>>1)|((i&1)<<l-1);
NTT(f,1),NTT(g,1);
for(int i=0;i<lim;i++)f[i]=f[i]*g[i]%mod;
NTT(f,0);
for(int i=0;i<=n+m;i++)printf("%lld ",f[i]);
}
II. P4173 殘缺的字串
經典字串匹配題。
#include <bits/stdc++.h>
using namespace std;
typedef double db;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;
#define gc getchar()
#define pb push_back
#define mem(x,v,n) memset(x,v,sizeof(int)*n)
#define cpy(x,y,n) memcpy(x,y,sizeof(int)*n)
const ld Pi=acos(-1);
const ll mod=998244353;
inline int read(){
int x=0; char s=gc;
while(!isdigit(s))s=gc;
while(isdigit(s))x=x*10+s-'0',s=gc;
return x;
}
ll ksm(ll a,ll b){
ll s=1;
while(b){
if(b&1)s=s*a%mod;
a=a*a%mod,b>>=1;
}
return s;
}
ll inv(ll x){return ksm(x,mod-2);}
const int N=1<<19;
const ll G=3;
const ll ivG=inv(3);
int r[N],pren;
void pre(int n){
if(n==pren)return;
for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
void NTT(int *g,int n,bool op){
pre(n);
static ull f[N],w[N]; w[0]=1;
for(int i=0;i<n;i++)f[i]=g[r[i]];
for(int l=1;l<n;l<<=1){
ull wn=ksm(op?G:ivG,(mod-1)/(l+l));
for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
for(int i=0;i<n;i+=l<<1)
for(int j=0;j<l;j++){
int t=w[j]*f[i|j|l]%mod;
f[i|j|l]=f[i|j]+mod-t,f[i|j]+=t;
}
if(l==(1<<16))for(int i=0;i<n;i++)f[i]%=mod;
}
if(op)for(int i=0;i<n;i++)g[i]=f[i]%mod;
else{
ll iv=inv(n);
for(int i=0;i<n;i++)g[i]=f[i]%mod*iv%mod;
}
}
int n,m,lim=1,a[N],b[N],res[N];
int ans[N],cnt;
string A,B;
int main(){
cin>>n>>m>>A>>B,reverse(A.begin(),A.end());
while(lim<m)lim<<=1;
for(int i=0;i<n;i++)if(A[i]!='*')a[i]=(A[i]-'a')*(A[i]-'a');
for(int i=0;i<m;i++)if(B[i]!='*')b[i]=1;
NTT(a,lim,1),NTT(b,lim,1);
for(int i=0;i<lim;i++)res[i]=1ll*a[i]*b[i]%mod;
mem(a,0,lim),mem(b,0,lim);
for(int i=0;i<n;i++)if(A[i]!='*')a[i]=A[i]-'a';
for(int i=0;i<m;i++)if(B[i]!='*')b[i]=B[i]-'a';
NTT(a,lim,1),NTT(b,lim,1);
for(int i=0;i<lim;i++)res[i]=(res[i]-2ll*a[i]*b[i]%mod+mod)%mod;
mem(a,0,lim),mem(b,0,lim);
for(int i=0;i<n;i++)if(A[i]!='*')a[i]=1;
for(int i=0;i<m;i++)if(B[i]!='*')b[i]=(B[i]-'a')*(B[i]-'a');
NTT(a,lim,1),NTT(b,lim,1);
for(int i=0;i<lim;i++)res[i]=(res[i]+1ll*a[i]*b[i])%mod;
NTT(res,lim,0);
for(int i=n-1;i<m;i++)if(!res[i])ans[++cnt]=i-n+2;
cout<<cnt<<endl;
for(int i=1;i<=cnt;i++)cout<<ans[i]<<" ";
return 0;
}