1. 程式人生 > 實用技巧 >題解 下降冪多項式乘法

題解 下降冪多項式乘法

題目傳送門

題目大意

給出兩個次數分別為\(n,m\)的下降冪多項式,表示為:

\[F(x)=\sum_{i=0}^{n} a_i x^{\underline{i}},G(x)=\sum_{i=0}^{n} b_i x^{\underline {i}} \]

求出一個下降冪多項式\(H(x)\)使得對於\(\forall x,H(x)=F(x)G(x)\)

\(n,m\le 10^5\)

思路

有一個思路特別\(\texttt{naive}\)但是碼起來會死人的方法,就是兩個下降冪多項式轉換成普通多項式,然後乘起來再轉換成下降冪多項式。時間複雜度\(\Theta(n\log ^2n)\),常數估計也賊大,估計是通過不了這道題的。

以下內容借鑑了\(\texttt{command-block}\)的題解。

我們發現我們其實用與多項式乘法相同的做法,我們把下降冪多項式轉換成點值表示法然後乘起來再轉換成下降冪多項式。問題就是我們如何把一個下降冪多項式轉換成點值表示法。

我們假設要轉換的下降冪多項式為\(F\),我們假設點值的\(\texttt{EGF}\)\(F_1\)

我們發現如果\(F(x)=x^{\underline {n}}\),那我們可以得到\(F_1(x)\)為:

\[\sum_{i=0}^{\infty} \frac{i^{\underline n}x^i}{i!}=\sum_{i=0}^{\infty} \frac{x^i}{(n-i)!}=x^n\sum_{i=0}^{\infty} \frac{x^i}{i!}=x^ne^x \]

而我們根據定義可以得到:

\[F_1(x)=\sum_{i=0}^{\infty} F(i) \frac{x^i}{i!} \]

又因為\(F(x)=\sum_{i=0}^{\infty} F[i]x^{\underline {i}}\),所以我們可以得到:

\[F_1(x)=\sum_{i=0}^{\infty} \frac{x^i}{i!} \sum_{j=0}^{\infty} F[j]i^{\underline {j}} \]

\[=\sum_{j=0}^{\infty} F[j]\sum_{i=0}^{\infty} \frac{i^{\underline {j}}x^i}{i!} \]

\[=\sum_{j=0}^{\infty} F[j]x^je^x \]

\[=e^x\sum_{j=0}^{\infty} F[j]x^j \]

於是,我們就可以直接普通多項式乘法\(\Theta(n\log n)\)求出一個下降冪多項式的點值的\(\texttt{EGF}\),求出來之後直接乘起來。那要從點值\(\texttt{EGF}\)轉換成下降冪多項式很顯然直接乘上\(e^{-x}\),這個東西用泰勒展開(顯然)就是:

\[\sum_{i=0}^{\infty} \frac{(-1)^ix^i}{i!} \]

\(\texttt{Code}\)

#pragma GCC optimize("Ofast")
#pragma GCC optimize("inline", "no-stack-protector", "unroll-loops")
#pragma GCC diagnostic error "-fwhole-program"
#pragma GCC diagnostic error "-fcse-skip-blocks"
#pragma GCC diagnostic error "-funsafe-loop-optimizations"

#include <bits/stdc++.h>
using namespace std;

#define SZ(x) ((int)x.size())
#define Int register int
#define mod 998244353
#define MAXN 1600005

int mul (int a,int b){return 1ll * a * b % mod;}
int dec (int a,int b){return a >= b ? a - b : a + mod - b;}
int add (int a,int b){return a + b >= mod ? a + b - mod : a + b;}
int qkpow (int a,int k){
	int res = 1;for (;k;k >>= 1,a = 1ll * a * a % mod) if (k & 1) res = 1ll * res * a % mod;
	return res;
}
int inv (int x){return qkpow (x,mod - 2);}

typedef vector <int> poly;

int up,w[MAXN],rev[MAXN];

void init_ntt (){
	int lim = 1 << 19;
	for (Int i = 0;i < lim;++ i) rev[i] = (rev[i >> 1] >> 1) | ((i & 1) << 18);
	int Wn = qkpow (3,(mod - 1) / lim);w[lim >> 1] = 1;
	for (Int i = lim / 2 + 1;i < lim;++ i) w[i] = mul (w[i - 1],Wn);
	for (Int i = lim / 2 - 1;i;-- i) w[i] = w[i << 1];
}

void ntt (poly &a,int lim,int type){
#define G 3
#define Gi 332748118
	static unsigned long long d[MAXN];
	for (Int i = 0,z = 19 - __builtin_ctz(lim);i < lim;++ i) d[rev[i] >> z] = a[i];
	for (Int i = 1;i < lim;i <<= 1)
		for (Int j = 0;j < lim;j += i << 1)
			for (Int k = 0;k < i;++ k){
				int x = 1ll * d[i + j + k] * w[i + k] % mod;
				d[i + j + k] = d[j + k] + mod - x,d[j + k] += x;
			}
	for (Int i = 0;i < lim;++ i) a[i] = d[i] % mod;
	if (type == -1){
		reverse (a.begin() + 1,a.begin() + lim);
		for (Int i = 0,Inv = inv (lim);i < lim;++ i) a[i] = mul (a[i],Inv);
	}
#undef G
#undef Gi 
}

poly operator * (poly a,poly b){
	int d = SZ (a) + SZ (b) - 1,lim = 1;while (lim < d) lim <<= 1;
	a.resize (lim),b.resize (lim);
	ntt (a,lim,1),ntt (b,lim,1);
	for (Int i = 0;i < lim;++ i) a[i] = mul (a[i],b[i]);
	ntt (a,lim,-1),a.resize (up + 1);
	return a;
}

template <typename T> inline void read (T &t){t = 0;char c = getchar();int f = 1;while (c < '0' || c > '9'){if (c == '-') f = -f;c = getchar();}while (c >= '0' && c <= '9'){t = (t << 3) + (t << 1) + c - '0';c = getchar();} t *= f;}
template <typename T,typename ... Args> inline void read (T &t,Args&... args){read (t);read (args...);}
template <typename T> inline void write (T x){if (x < 0){x = -x;putchar ('-');}if (x > 9) write (x / 10);putchar (x % 10 + '0');}

poly I,Iv,A,B;
int n,m,fac[MAXN],ifac[MAXN];

signed main(){
	init_ntt (),read (n,m);up = n + m;A.resize (n + 1),B.resize (m + 1);
	for (Int i = 0;i <= n;++ i) read (A[i]);
	for (Int i = 0;i <= m;++ i) read (B[i]);
	fac[0] = 1;for (Int i = 1;i <= up;++ i) fac[i] = mul (fac[i - 1],i);
	ifac[up] = inv (fac[up]);for (Int i = up;i;-- i) ifac[i - 1] = mul (ifac[i],i);
	I.resize (up + 1);for (Int i = 0;i <= up;++ i) I[i] = ifac[i];
	A = A * I,B = B * I;
	for (Int i = 0;i <= up;++ i) A[i] = mul (mul (A[i],B[i]),fac[i]);
	Iv.resize (up + 1);for (Int i = 0;i <= up;++ i) Iv[i] = (i & 1 ? mod - ifac[i] : ifac[i]);
	A = A * Iv;
	for (Int i = 0;i <= up;++ i) write (A[i]),putchar (' ');
	putchar ('\n');
	return 0;
}

P.S.

本來沒有卡過去,結果直接一波卡常就卡成\(\text {rank 3}\)了(至少在這個時候還是\(\text{rank 3}\))。。。(霧