1. 程式人生 > 實用技巧 >多項式的高階運算

多項式的高階運算

基礎:FFT與NTT

多項式求乘法逆元

【模板】多項式乘法逆

\(Code:\)

int n;
ll A[N], B[N], C[N], r[N];
ll limi, l;
inline ll quickpow(ll x, ll k)...
inline void ntt(ll *a, int type) {...//此處已經讓type = 1的乘inv了
void sol(int deg, ll *a, ll *b) {//b is 逆元
	if (deg == 1) {
		b[0] = quickpow(a[0], P - 2);
		return ;
	}
	sol((deg + 1) >> 1, a, b);
	
	limi = 1, l = 0;
	while (limi <= (deg << 1))	limi <<= 1, ++l;
	for (register int i = 1; i <= limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (l - 1));
	for (register int i = 0; i < deg; ++i)	C[i] = a[i];//轉移到C防變化 
	for (register int i = deg; i < limi; ++i)	C[i] = 0;//多次清空更保險 
	ntt(b, 1); ntt(C, 1);
	for (register int i = 0; i < limi; ++i)//B = 2B' - AB'B' = B'(2 - AB')
		b[i] = ((2ll - b[i] * C[i]) % P + P) % P * b[i] % P;
	ntt(b, -1);
	for (register int i = deg; i <= limi; ++i)//多次清空更保險 
		b[i] = 0;
}
int main() {
	read(n);
	for (register int i = 0; i < n; ++i)	read(A[i]);
	sol(n, A, B);
	for (register int i = 0; i < n; ++i)
		printf("%lld ", B[i]);
	return 0;
}

這裡給一個簡化的板子,供複習:

inline void get_inv(ll *a, ll *b, int deg) {//deg:項數 
	if (deg == 1)	return b[0] = inv(a[0]), void() ; 
    
	get_inv(a, b, (deg + 1) >> 1);
	limi = 1; len = 0;
	while (limi <= (deg << 1)) limi <<= 1, len++;
	for (register int i = 0; i < limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (len - 1));
        
	for (register int i = 0; i < deg; ++i)	c[i] = a[i], d[i] = b[i];
	for (register int i = deg; i < limi; ++i)	c[i] = d[i] = 0;
    
	ntt(c, 1), ntt(d, 1);
	for (register int i = 0; i < limi; ++i)	c[i] = c[i] * d[i] % P * d[i] % P;
	ntt(c, -1);
	for (register int i = 0; i < deg; ++i)	b[i] = ((2 * b[i] - c[i]) % P + P) % P;
}

多項式開根號

題意

  • 給A(x),其中a0 = 1,求B(x),使得B(x)^2 = A(x) (mod x^n)

思路簡析

與多項式求逆相同,由於一平方就mod x^m -> mod x^(2m),我們考慮遞迴求解。

表示式

同樣假設我們已經求出B(x)的一半b(x),那麼:

A * b = 1(mod x^m)
A * B = 1(mod x^m)
∴B - b = 0(mod x^m)
兩邊平方:
B^2 - 2 * B * b + b^2 = 0(mod x^(2*m))
據B^2 = A(mod x^(2*m)):
A - 2 * B * b + b^2 = 0(mod x^(2*m))
於是:
B = (A/b + b) / 2

配合多項式求逆解出B(x)。

遞迴邊界

模板題非常友善,告訴我們 a0 = 1,於是遞迴邊界為

if (deg == 1) {b[0] = 1; return ;}

如果題目沒有那麼友善,那麼我們或許可以多random幾個數 我們需要用二次剩餘之類的麻煩的東西,或者考慮換一種演算法。

淺談二次剩餘

Code:

void get_sqrt(ll *a, ll *b, int deg) {
	if (deg == 1) {b[0] = 1; return ;}
	get_sqrt(a, b, (deg + 1) >> 1);
  //get_len
	limi = 1, len = 0;
	while (limi <= (deg << 1))	limi <<= 1, len++;
	for (register int i = 0; i <= limi; ++i)
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (len - 1));
        
  //copy and multiply
	for (register int i = 0; i <= limi; ++i)	bn[i] = 0;//attention
	get_inv(b, bn, deg);
	for (register int i = 0; i < deg; ++i)	C[i] = a[i];
	for (register int i = deg; i <= limi; ++i)	C[i] = 0;
	ntt(C, 1); ntt(bn, 1);
	for (register int i = 0; i <= limi; ++i)
		C[i] = C[i] * bn[i] % P;
	ntt(C, -1);
	for (register int i = 0; i < deg; ++i)	b[i] = (C[i] + b[i]) * inv2 % P;
	for (register int i = deg; i <= limi; ++i)	b[i] = 0;
}

注意:

  1. 用陣列前一定注意清空。我也不知道為什麼,反正不清空就會出錯。估計是NTT的禍吧。

多項式除法

  • P4512 【模板】多項式除法

  • A(x) * B(x) + C(x) = D(x),給出D(x), A(x),求B(x),C(x)。(類似高精除)

  • m = deg(A) < 100,000, n = deg(D) <= 100,000,m <= n

思路簡析

我們發現神奇的事情:把A,B,C,D都翻轉過來,(把C加到D後面),等式仍然成立。並且還有一個好處:C轉過來後D的0~n-m項都不受C的影響,而B又肯定超不過n-m+1項。因此我們可以藉助反轉後的A,D陣列算出B陣列,然後什麼都好搞了。

Code:

int main() {
	read(n); read(m); n++; m++;//n,m變成項數
	for (register int i = 0; i < n; ++i)	read(D[i]), Dbp[i] = D[i];
	for (register int i = 0; i < m; ++i)	read(A[i]), Abp[i] = A[i];//backup
	Reverse(A, m);
	Reverse(D, n);
	for (register int i = n - m + 1; i < n; ++i)
		A[i] = D[i] = 0;
	get_inv(A, An, n - m + 1);
	mul(D, An, B, n - m + 1, n - m + 1);
    //D(n-m+1項) * An(n-m+1項) -> B
	Reverse(B, n - m + 1);
	mul(Abp, B, AB, m, n - m + 1);
	for (register int i = 0; i < m - 1; ++i)
		C[i] = (Dbp[i] - AB[i] + P) % P;
    ...
}

注意

  • 在算D*inv(A)時,保險起見,只保留D和A的0~n-m項,且對A,D做備份,算C時用。

  • 什麼時候用n-m,什麼時候用n-m+1,要分清楚!

  • 此時多項式變數名逐漸增多,注意區分,不要把An寫成A!

分治FFT

(其實我覺得對於我這個幾乎只寫NTT的人來說,叫分治NTT比較好)

【模板】分治 FFT

簡單說一下,分治FFT用到了CDQ分治的思想,但不用非得學CDQ分治,畢竟這個思想還是比較好理解的,之前也經常用到。簡而言之,這裡的CDQ分治思想就是:每次只考慮左半段對右半段的貢獻,先遞迴解決左半段,然後讓右半段加上左半段的貢獻,再遞迴解決右半段。這樣,一次次貢獻的加和就組成了每個位置的值。

將題意稍稍轉化,f[i] = g[0 ~ i]與f[0 ~ i]的卷積(多項式乘法)。剩下的推導部分可以通過手玩來推導。

最關鍵的兩條語句:

    for (int i = l;i <= mid; i++) a[i-l] = f[i];
    for (int i = 0;i < len; i++) b[i] = g[i];

\(Code:\) my code

(主要篇幅是NTT,其實重點就在於以上兩條語句,NTT只是工具)

多項式求ln

在學習微積分後,我再學ln,感覺舒適了很多。

求ln很簡單,兩邊求個到,用一下多項式求逆,再積分即可。

思維難度低,程式碼量大。

新增兩條背誦語句:

inline void dao(ll *a, ll *b, int n) {
	for (register int i = 1; i < n; ++i)	b[i - 1] = i * a[i] % P; b[n - 1] = 0;
}
inline void ji(ll *a, ll *b, int n) {
	for (register int i = 1; i < n; ++i)	b[i] = a[i - 1] * inv(i) % P; b[0] = 0;
}

其實也不是很難背,考慮 求導/積分 完後的b[i](或b[i - 1])是從哪裡轉移來的,就很容易理解了。

剩餘部分程式碼:

	get_inv(A, An, n); dao(A, Ad, n);
	
	limi = 1, len = 0;
	while (limi <= (n << 1))	limi <<= 1, len++;
	for (register int i = 0; i <= limi; ++i) 
		r[i] = (r[i >> 1] >> 1) | ((i & 1) << (len - 1));
	
	ntt(Ad, 1); ntt(An, 1);
	for (register int i = 0; i <= limi; ++i)	A[i] = Ad[i] * An[i] % P;
	ntt(A, -1);
	ji(A, C, n);
	for (register int i = 0; i < n; ++i)
		printf("%lld ", C[i] % P);