1. 程式人生 > 其它 >拉格朗日插值 & FFT & NTT 及其應用

拉格朗日插值 & FFT & NTT 及其應用

進軍多項式

進軍多項式。

1. 拉格朗日插值

1.1. 普通插值

首先給出公式:

\[\large F(x)=\sum_{k=1}^n(y_k\prod_{i=1,i\neq k}^n \dfrac{x-x_i}{x_k-x_i}) \]

對於每對點值 \((x_k,y_k)\),我們需要構造出一個函式 \(G(x)\),使得其在 \(x=x_k\) 處的取值為 \(y_k\),其餘處取值為 \(0\)

首先建構函式 \(D(x)=\prod_{i=1,i\neq k}^n x-x_i\)。顯然當 \(i\neq k\) 時,有 \(D(x_i)=0\)。但是現在我們不能保證 \(D(x_k)=y_k\)。為了使 \(D(x_k)=y_k\)

,我們只需要先將其除以 \(D(x_k)\),再乘以 \(y_k\) 即可,這就有了上面的拉格朗日插值公式

通常情況下,題目會要求我們求出 \(F(x)\) 在給定某個 \(x\) 處的取值,此時我們不把 \(x\) 看做函式的一個元,而是直接將 \(x\) 帶入上式即可。時間複雜度為 \(\mathcal{O}(n^2)\),程式碼見例題 I。

1.2. 連續取值插值

很多情況下,我們求出的點值 \(x_i\) 滿足 \(x_i=i\),即 \(x_i\) 是連續的。此時我們重新寫一下公式:

\[\sum_{k=1}^n(y_k\prod_{i=1,i\neq k}^n \dfrac{x-i}{k-i}) \]

\(p_i=\prod_{j=1}^ix-i\)

\(s_i=\prod_{j=i+1}^n x-i\),這些可以線性預處理,那麼上述柿子右邊就變成了

\[\dfrac{p_{k-1}s_{k+1}}{(k-1)!\times (-1)^{n-k}(n-k)!} \]

預處理階乘就可以線性插值,應用見例題 II。

1.3. 求 \(F(x)\) 各項係數

建構函式 \(D(x)=\prod_{i=1}^n(x-x_i)\),設 \(d_k=y_i\prod_{i=1,i\neq k}^n\dfrac{1}{x_k-x_i}\),則 \(F(x)=\sum_{k=1}^n \dfrac{d_kD(x)}{x-x_k}\)。注意到 \(D(x)\)

的各項係數可以在 \(n^2\) 的時間內暴力處理出來,而對於每個 \(k\),我們可以線性將 \(D(x)\) 除以一個一次多項式。最後加和即可。時間複雜度 \(\mathcal{O}(n^2)\)

1.4. 例題

I. P4781 【模板】拉格朗日插值

板子題。

#include <bits/stdc++.h>
using namespace std;

#define ll long long

const ll mod=998244353;
const int N=2e3+5;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
}

int n,k,x[N],y[N],ans;
int main(){
	cin>>n>>k;
	for(int i=1;i<=n;i++)cin>>x[i]>>y[i];
	for(ll i=1,s1=1,s2=1;i<=n;i++,s1=s2=1){
		for(int j=1;j<=n;j++)if(i!=j)s1=s1*(k-x[j])%mod,s2=s2*(x[i]-x[j])%mod;
		ans=(ans+y[i]*s1%mod*ksm(s2,mod-2))%mod;
	} cout<<(ans%mod+mod)%mod<<endl;
	return 0;
}

II. CF622F The Sum of the k-th Powers

經典題。一個結論是自然數 \(k\) 次方和是 \(k+1\) 次多項式,那麼只需要帶 \((i,i^k)\ (i\in [0,k+1])\) 插值即可。注意到當 \(i=0\) 時對答案無貢獻,所以在插值的時候可以跳過(但是預處理 \(p,s\) 的時候仍應考慮 \(i=0\)\(n-i\) 的這個 \(n\))。

時間複雜度 \(\mathcal{O}(k\log k)\),使用線性篩篩 \(i^k\) 可以做到線性。

#include <bits/stdc++.h>
using namespace std;

#define ll long long

const ll mod=1e9+7;
const int N=1e6+5;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
} ll inv(ll x){return ksm(x,mod-2);}

ll n,k,ans;
ll p[N],s[N],fc[N];
int main(){
	cin>>n>>k,s[k+2]=1;
	for(int i=0;i<=k+1;i++)fc[i]=i?fc[i-1]*i%mod:1,p[i]=i?p[i-1]*(n-i)%mod:n;
	for(int i=k+1;~i;i--)s[i]=i==k+1?n-i:s[i+1]*(n-i)%mod;
	for(ll i=1,res=0;i<=k+1;i++){
		res=(res+ksm(i,k))%mod;
		ans=(ans+p[i-1]*s[i+1]%mod*res%mod*ksm(fc[i]*((k-i)&1?1:-1)*fc[k+1-i]%mod,mod-2))%mod;
	} cout<<(ans%mod+mod)%mod<<endl;
	return 0;
}

2. FFT

對於一個 \(n\) 次多項式 \(F(x)=a_0+a_1x+a_2x^2+\cdots+a_nx^n\),我們可以用 \(n+1\) 個點值 \((x_k,y_k)\) 唯一確定該多項式。即 \(y_k=\sum_{i=0}^na_ix_k^i\)(注意下文的 \(n\) 表示不小於多項式次數 \(+1\) 的最小的 \(2\) 的冪)。設 \(A=F\times G\),注意到 \(A(x)=F(x)G(x)\),其中 \(x\) 是一個確定的值。因此,我們只需要將一個多項式快速(\(n\log n\))轉點值,再快速轉成係數表示,就可以做到時間複雜度 \(n\log n\) 的多項式乘法。

點值的取法很有講究,高明的方法能夠極大化地減小時間複雜度。這裡我們採用 \(n\) 次單位根 \(\omega_n\),並主要利用以下性質:

  • \(\omega_n^k=\omega_{2n}^{2k}\)
  • \(\omega_n^k=-\omega_n^{k+n/2}\)
  • \(\omega_n^n=1\)
  • \(\omega_n=\cos(\dfrac{2\pi}{n})+i\sin(\dfrac{2\pi}{n})\),從而計算 \(n\) 次單位根。

\(n=4\) 時,\(F(x)=a_0+a_1x+a_2x^2+a_3x^3=(a_0+a_2x^2)+x(a_1+a_3x^2)\)。設 \(L(x)=a_0+a_2x\)\(R(x)=a_1+a_3x\),那麼 \(F(x)=L(x^2)+xR(x^2)\)。將單位根 \(\omega_n\) 帶入,那麼當 \(k<\dfrac{n}{2}\) 時,\(F(\omega_n^k)=L(\omega_{n/2}^k)+\omega_{n}^kR(\omega_{n/2}^k)\)\(F(\omega_n^{k+n/2})=L(\omega_{n/2}^k)-\omega_n^kR(\omega_{n/2}^k)\)。注意到兩式只有一個正負號不同。因此,如果我們已經知道了 \(L,R\)\(\omega_{n/2}^i,\ i\in[0,\dfrac{n}{2})\) 處的取值,那麼我們就可以線上性時間內求出 \(F\)\(\omega_n^i, \ i\in[0,n)\) 處的取值。考慮遞迴樹的每一層都是線性的,因此總複雜度為 \(n\log n\)

遞迴處理很慢,於是我們使用迭代:考慮係數 \(a_i\)\(L\) 分治為 \(0\),向 \(R\) 分治為 \(1\),那麼顯然 \(a_i\) 最終形成的 “分治序列” 形成的二進位制數倒過來就是 \(i\) 的二進位制表示。考慮求出 \(r_i\) 表示 \(i\) 二進位制翻轉後得到的數。假設 \(r_0,r_1,\cdots,r_{i-1}\) 都已經求出 ,那麼有 \(r_i=\lfloor\dfrac{r_{\lfloor\dfrac{i}{2}\rfloor}}{2}\rfloor+\dfrac{n}{2}\times(i\bmod 2)\)。左邊是 \(i\) 不考慮最低位(即假設其為 \(0\))時二進位制翻轉得到的數,然後再考慮 \(i\) 的最低位的影響即可。然後對於每一對無序對 \((i,r_i)\),將 \(a_i\)\(a_{r_i}\) 交換,那麼最終的分治樹的形態類似線段樹,直接從最底層向上迭代即可。該操作被稱為蝴蝶迭代


卡常技巧

  • 不要寫建構函式。
  • 過載加減乘運算子放到 struct 裡面。

3. NTT

由於複數運算很慢,而通常情況下我們是在模意義下進行多項式運算,所以當模數取一些特殊值時,我們可以用 \(\mathbb{Z}\) 中的數 \(g\) 代替單位根的複數運算。具體地,\(g\) 需要是模 \(p\) 的原根。

鴿著。

4. FFT 優化字串匹配

Trick 1:翻轉一個序列常常可以使關於它的某些計算變成卷積的形式。

對於一個文字串 \(s\) 與匹配串 \(t\)(下標從 \(0\) 開始),設它們的長度分別為 \(n\)\(m\)。稱它們在位置 \(p\) 匹配,當且僅當對於任意 \(i\in[0,m)\),有 \(s_{p-m+i+1}=t_i\)。不難發現它的充分條件為 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2=0\)。展開,得到 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t^2_{i}-2s_{p-m+i+1}t_i)=0\)。注意到前面兩項容易預處理得到,但後面一項同時關於兩個字串,比較麻煩。又發現它的形式類似卷積,但又不是卷積:翻轉字串 \(t\) 即可。因此柿子變為 \(\sum_{i=0}^{m-1}(s^2_{p-m+i+1}+t_{m-i-1}^2-2s_{p-m+i+1}t_{m-i-1})\),後面一項即 \(2\sum_{0\leq i<m,i+j=p}s_jt_i\),FFT 計算即可。

對於有萬用字元的字串,我們不妨將該位置上的值設為 \(0\),然後乘到上面的柿子裡去,即 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2s_{p-m+i+1}t_i\)。化簡得到 \(\sum_{0\leq i<m,i+j=p}(s^3_jt_i-2s^2_jt^2_i+s_jt^3_i)\),做 6 次 DFT + 1 次 IDFT 即可。

Trick 2:一般情況下,上述方法足夠應付多數題目。但是如果萬惡的出題人卡了 FFT 精度就涼涼了。因此,為了保險,應儘量使用 NTT。但是 NTT 也有一個致命問題:如果計算出來的值剛好是 \(998244353\) 的倍數,那麼就會在不匹配的地方判定匹配。看起來好像沒有解決辦法了?非也。在 P4173 殘缺的字串 這題的討論區 https://www.luogu.com.cn/discuss/show/303076,我找到了一個高明的手段解決這個問題:注意到上述柿子最大值可達到 \(3\times m\times (|\mathbb{\Sigma}|-1)^4\),如果 \(m\)\(5\times 10^5\),字符集取大小 \(26\),那麼數量級為 \(1.5\times 10^6\times 25^4\approx 6\times 10^{11}\),顯然不太行。但是實際上我們沒必要將整個 \(s\)\(t\) 乘進去,因為我們關心的只是某一位是否是萬用字元,而具體這一位是什麼並不重要。因此,我們令 \(S_i=[s_i\neq\texttt{*}]\)\(T_i=[t_i\neq\texttt{*}]\),只需要將 \(S,T\) 而非 \(s,t\) 乘入即可。 這時最大值僅為 \(3\times 5\times 10^5\times 25^2=9.375\times 10^8<998244353\),有驚無險地保證了正確性。當然,不同題目還應根據 \(m\) 和字符集大小的不同取值具體分析正確性。

也許看了上述分析的你認為:既然將 \(S,T\) 乘進去,那麼不如直接將是萬用字元的位置當做 \(0\) 來計算 \(\sum_{i=0}^{m-1}(s_{p-m+i+1}-t_i)^2\) 不就好了?非也。因為這樣在計算只關於某一個字串的項時,考慮不到另外一個字串的對應位置是否是萬用字元。因此,6 次 DFT 是逃不掉的。最終的柿子即為 \(\sum_{0\leq i<m,i+j=p}((s_j^2S_j)\times T_i-2(s_jS_j)\times (t_iT_i)+S_j\times (t^2_iT_i))\)​。

例題 & 程式碼可以看 II.

5. FFT & NTT 例題

I. P3803 【模板】多項式乘法(FFT)

FFT:

#include <bits/stdc++.h>
using namespace std;

#define ld double

const int N=1<<21;
const ld Pi=acos(-1);

struct com{
	ld x,y;
	com operator + (com b){return (com){x+b.x,y+b.y};}
	com operator - (com b){return (com){x-b.x,y-b.y};}
	com operator * (com b){return (com){x*b.x-y*b.y,x*b.y+y*b.x};}
}a[N],b[N];

int n,m,lim=1,bit,r[N];
void FFT(com *a,int tp){
	for(int i=0;i<lim;i++)if(i<r[i])swap(a[i],a[r[i]]);
	for(int l=1;l<lim;l<<=1){
		com wn={cos(Pi/l),tp*sin(Pi/l)};
		for(int j=0;j<lim;j+=(l<<1)){
			com w={1,0},x,y;
			for(int k=0;k<l;k++,w=w*wn)
				x=a[j+k],y=w*a[j+k+l],a[j+k]=x+y,a[j+k+l]=x-y;
		}
	}
}
int main(){
	cin>>n>>m;
	for(int i=0;i<=n;i++)scanf("%lf",&a[i].x);
	for(int i=0;i<=m;i++)scanf("%lf",&b[i].x);
	while(lim<=n+m)lim<<=1,bit++;
	for(int i=1;i<lim;i++)r[i]=(r[i>>1]>>1)|((i&1)<<bit-1);
	FFT(a,1),FFT(b,1);
	for(int i=0;i<lim;i++)a[i]=a[i]*b[i];
	FFT(a,-1);
	for(int i=0;i<=n+m;i++)cout<<(int)(a[i].x/lim+0.5)<<" ";
	return 0;
}

NTT:

#include <bits/stdc++.h>
using namespace std;

#define ll long long
#define ull unsigned long long
#define gc getchar()

inline int read(){
	int x=0; char s=gc;
	while(!isdigit(s))s=gc;
	while(isdigit(s))x=s-'0',s=gc;
	return x;
}

const ll mod=998244353;
const int N=1<<21;

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	} return s;
} ll inv(ll x){return ksm(x,mod-2);}

const int G=3;
const int ivG=inv(3);

ll tr[N],lim=1,l;
ll n,m,f[N],g[N];
void NTT(ll *a,bool tp){
	static ull f[N],w[N]; w[0]=1;
	for(int i=0;i<lim;i++)f[i]=a[tr[i]];
	for(int l=1;l<lim;l<<=1){
		ll wn=ksm(tp?G:ivG,(mod-1)/(l+l));
		for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
		for(int i=0;i<lim;i+=l<<1){
			for(int j=0;j<l;j++){
				int y=w[j]*f[i|j|l]%mod;
				f[i|j|l]=f[i|j]+mod-y,f[i|j]+=y;
			}
		} if(l==(1<<17))for(int i=0;i<lim;i++)f[i]%=mod;
	}
	if(!tp){
		ll iv=inv(lim);
		for(int i=0;i<lim;i++)a[i]=f[i]%mod*iv%mod;
	} else for(int i=0;i<lim;i++)a[i]=f[i]%mod;
}
int main(){
	cin>>n>>m;
	for(int i=0;i<=n;i++)f[i]=read();
	for(int i=0;i<=m;i++)g[i]=read();
	while(lim<=n+m)lim<<=1,l++;
	for(int i=1;i<lim;i++)tr[i]=(tr[i>>1]>>1)|((i&1)<<l-1);
	NTT(f,1),NTT(g,1);
	for(int i=0;i<lim;i++)f[i]=f[i]*g[i]%mod;
	NTT(f,0);
	for(int i=0;i<=n+m;i++)printf("%lld ",f[i]);
}

II. P4173 殘缺的字串

經典字串匹配題。

#include <bits/stdc++.h>
using namespace std;

typedef double db;
typedef long long ll;
typedef long double ld;
typedef unsigned long long ull;

#define gc getchar()
#define pb push_back
#define mem(x,v,n) memset(x,v,sizeof(int)*n)
#define cpy(x,y,n) memcpy(x,y,sizeof(int)*n)

const ld Pi=acos(-1);
const ll mod=998244353;

inline int read(){
	int x=0; char s=gc;
	while(!isdigit(s))s=gc;
	while(isdigit(s))x=x*10+s-'0',s=gc;
	return x;
}

ll ksm(ll a,ll b){
	ll s=1;
	while(b){
		if(b&1)s=s*a%mod;
		a=a*a%mod,b>>=1;
	}
	return s;
}
ll inv(ll x){return ksm(x,mod-2);}

const int N=1<<19;
const ll G=3;
const ll ivG=inv(3);

int r[N],pren;
void pre(int n){
	if(n==pren)return;
	for(int i=1;i<n;i++)r[i]=(r[i>>1]>>1)|(i&1?n>>1:0);
}
void NTT(int *g,int n,bool op){
	pre(n);
	static ull f[N],w[N]; w[0]=1;
	for(int i=0;i<n;i++)f[i]=g[r[i]];
	for(int l=1;l<n;l<<=1){
		ull wn=ksm(op?G:ivG,(mod-1)/(l+l));
		for(int i=1;i<l;i++)w[i]=w[i-1]*wn%mod;
		for(int i=0;i<n;i+=l<<1)
			for(int j=0;j<l;j++){
				int t=w[j]*f[i|j|l]%mod;
				f[i|j|l]=f[i|j]+mod-t,f[i|j]+=t;
			}
		if(l==(1<<16))for(int i=0;i<n;i++)f[i]%=mod;
	}
	if(op)for(int i=0;i<n;i++)g[i]=f[i]%mod;
	else{
		ll iv=inv(n);
		for(int i=0;i<n;i++)g[i]=f[i]%mod*iv%mod;
	}
}

int n,m,lim=1,a[N],b[N],res[N];
int ans[N],cnt;
string A,B;
int main(){
	cin>>n>>m>>A>>B,reverse(A.begin(),A.end());
	while(lim<m)lim<<=1;
	
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=(A[i]-'a')*(A[i]-'a');
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=1;
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=1ll*a[i]*b[i]%mod;
	
	mem(a,0,lim),mem(b,0,lim);
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=A[i]-'a';
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=B[i]-'a';
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=(res[i]-2ll*a[i]*b[i]%mod+mod)%mod;
	
	mem(a,0,lim),mem(b,0,lim);
	for(int i=0;i<n;i++)if(A[i]!='*')a[i]=1;
	for(int i=0;i<m;i++)if(B[i]!='*')b[i]=(B[i]-'a')*(B[i]-'a');
	NTT(a,lim,1),NTT(b,lim,1);
	for(int i=0;i<lim;i++)res[i]=(res[i]+1ll*a[i]*b[i])%mod;
	
	NTT(res,lim,0);
	for(int i=n-1;i<m;i++)if(!res[i])ans[++cnt]=i-n+2;
	cout<<cnt<<endl;
	for(int i=1;i<=cnt;i++)cout<<ans[i]<<" ";
	
	return 0;
}