1. 程式人生 > 其它 >【luogu P4725】【模板】多項式對數函式(多項式 ln)(NTT)

【luogu P4725】【模板】多項式對數函式(多項式 ln)(NTT)

【模板】多項式對數函式(多項式 ln)

題目連結:luogu P4725

題目大意

給定一個 n-1 次多項式 \(A(x)\)(題目保證常數項 \(a_0=1\)),求多項式 \(B(x)\) 使得 \(B(x)\equiv\ln A(x)\pmod{x^n}\)。

思路

設 \(F(x)=\ln x\),\(G(x)=F(A(x))=\ln A(x)\)。
考慮對兩邊同時求導:
\(G'(x)=F'(A(x))A'(x)\)(這個是複合函式求導公式)
然後根據 \((\ln x)'=\dfrac{1}{x}\)
\(G'(x)=\dfrac{A'(x)}{A(x)}\)
\(G(x)=\displaystyle\int\dfrac{A'(x)}{A(x)}\,dx\),其中積分常數取 0:因為題目保證 \(a_0=1\),所以 \(G(0)=\ln A(0)=\ln 1=0\)。

因此只需:先求出 \(A'(x)\),再用多項式求逆求出 \(A(x)^{-1}\bmod x^n\),兩者用 NTT 相乘後取前 n 項,最後逐項積分(常數項為 0)即可。

程式碼

#include<cstdio>
#include<cstring>
#include<algorithm>
// Shorthand for long long (not referenced elsewhere in this listing).
#define ll long long
// Prime modulus for all arithmetic: 998244353 = 119 * 2^23 + 1, which supports
// power-of-two NTT lengths up to 2^23 with root G = 3 (declared below).
#define mo 998244353
// clr: zero the first n ints starting at f.
// cpy: copy the first n ints from g into f.
#define clr(f, n) memset(f, 0, (n) * sizeof(int))
#define cpy(f, g, n) memcpy(f, g, (n) * sizeof(int))

using namespace std;

// Capacity of every coefficient buffer in this file.
// NOTE(review): presumably sized for n <= 1e5 with room for NTT padding (8x) —
// confirm against the problem limits if reused.
const int N = 100000 * 8 + 1;
// n: input length; f: input coefficients, overwritten with the answer;
// an: bit-reversal table filled by get_an and read by NTT;
// inv: modular inverses inv[i] = i^{-1} mod mo for 1 <= i < N (filled by Init);
// G: NTT root used for the forward transform; Gv: its inverse (set by Init).
int n, f[N], an[N], inv[N], G = 3, Gv;

// Modular addition; both operands are expected to lie in [0, mo).
int jia(int x, int y) {
	int sum = x + y;
	if (sum >= mo) sum -= mo;
	return sum;
}

// Modular subtraction; both operands are expected to lie in [0, mo).
int jian(int x, int y) {
	int diff = x - y;
	if (diff < 0) diff += mo;
	return diff;
}

// Modular multiplication, widened to 64 bits before reducing.
int cheng(int x, int y) {
	return static_cast<int>(static_cast<long long>(x) * y % mo);
}

// Binary exponentiation: x^y mod mo.
int ksm(int x, int y) {
	int result = 1;
	for (; y; y >>= 1) {
		if (y & 1) result = cheng(result, x);
		x = cheng(x, x);
	}
	return result;
}

// One-time setup: inverse of the NTT root, and the table inv[i] = i^{-1} mod mo
// for 1 <= i < N (inv[0] is a placeholder set to 1).
void Init() {
	Gv = ksm(G, mo - 2); // Fermat inverse, since mo is prime
	inv[0] = 1;
	inv[1] = 1;
	// Linear recurrence: i^{-1} = -(mo / i) * (mo % i)^{-1}  (mod mo).
	for (int i = 2; i < N; ++i) {
		const int quot = mo / i;
		const int rem = mo % i;
		inv[i] = cheng(inv[rem], mo - quot);
	}
}

// Fill an[0..limit) with the bit-reversal of each index over l_size bits.
// an[i] is derived from an[i / 2] (already computed, since i / 2 < i), shifted
// right by one, with i's lowest bit moved to the top position.
void get_an(int limit, int l_size) {
	for (int i = 0; i < limit; ++i) {
		const int top_bit = (i & 1) << (l_size - 1);
		an[i] = (an[i >> 1] >> 1) | top_bit;
	}
}

// In-place iterative number-theoretic transform of length `limit` modulo mo.
// op == 1: forward transform with root G.
// op == -1: inverse transform with Gv, then every entry is scaled by 1/limit.
// Expects `an` to hold the bit-reversal table for this `limit` (see get_an).
void NTT(int *f, int limit, int op) {
	// Bit-reversal permutation; each pair is swapped once (at its larger index).
	for (int i = 0; i < limit; i++) {
		const int j = an[i];
		if (j < i) std::swap(f[i], f[j]);
	}
	for (int half = 1; half < limit; half <<= 1) {
		const int step = half << 1;
		// Root of unity of order `step` (inverse root for the inverse transform).
		const int root = ksm(op == 1 ? G : Gv, (mo - 1) / step);
		for (int base = 0; base < limit; base += step) {
			int w = 1;
			for (int k = 0; k < half; k++) {
				int &lo = f[base + k];
				int &hi = f[base + half + k];
				const int u = lo;
				const int v = cheng(w, hi);
				lo = jia(u, v);
				hi = jian(u, v);
				w = cheng(w, root);
			}
		}
	}
	if (op != -1) return;
	const int scale = ksm(limit, mo - 2);
	for (int i = 0; i < limit; i++) f[i] = cheng(f[i], scale);
}

// Pointwise product in the transformed domain: f[i] = f[i] * g[i] mod mo
// for 0 <= i < limit (processed front to back).
void px(int *f, int *g, int limit) {
	int i = 0;
	while (i < limit) {
		f[i] = cheng(f[i], g[i]);
		++i;
	}
}

// Multiply the polynomials f and g, each taken as its first n coefficients, and
// leave the first m coefficients of the product in f (entries m..limit-1 are
// zeroed). g is only read. f must have room for `limit` ints, where `limit` is
// the NTT length (the smallest power of two > 2n).
void times(int *f, int *g, int n, int m) {
	static int tmp[N];
	int limit = 1, l_size = 0; while (limit <= n + n) limit <<= 1, l_size++;
	// Zero-pad BOTH operands to `limit` so the cyclic convolution equals the
	// ordinary product. Previously only g was padded and f's tail was silently
	// assumed to be zero already.
	clr(f + n, limit - n);
	cpy(tmp, g, n); clr(tmp + n, limit - n);
	get_an(limit, l_size);
	NTT(f, limit, 1); NTT(tmp, limit, 1);
	px(f, tmp, limit); NTT(f, limit, -1);
	// Truncate to m terms; guard against a negative memset size when m >= limit.
	if (m < limit) clr(f + m, limit - m);
	clr(tmp, limit);
}

// Replace the first n coefficients of f with those of 1 / f mod x^n, using
// Newton iteration: w <- 2w - w^2 f, doubling the precision each round.
// Only f[0..n) is read, so the caller need not zero-pad f beyond n.
// Precondition: f[0] != 0 mod mo (otherwise no inverse exists).
// Uses static scratch buffers, which are fully cleared again before returning
// so the function can be called repeatedly.
void invp(int *f, int n) {
	static int w[N], r[N], tmp[N];
	w[0] = ksm(f[0], mo - 2);
	int limit = 1, l_size = 0;
	// Stop as soon as precision len/2 reaches n (the old `<=` condition did a
	// wasted extra doubling whenever n was a power of two).
	for (int len = 2; (len >> 1) < n; len <<= 1) {
		limit = len; l_size++; get_an(limit, l_size);
		// Read f truncated to n terms: 1/f mod x^n depends only on f mod x^n.
		const int fl = std::min(limit, n);
		cpy(r, w, len >> 1);
		cpy(tmp, f, fl); clr(tmp + fl, limit - fl); NTT(tmp, limit, 1);
		NTT(r, limit, 1); px(r, tmp, limit);
		// r = w*f (cyclic, length len). Its low half is 1 plus wrap-around
		// terms, so drop it and keep the error term x^{len/2} * e.
		NTT(r, limit, -1); clr(r, limit >> 1);
		cpy(tmp, w, len); NTT(tmp, limit, 1);
		NTT(r, limit, 1); px(r, tmp, limit);
		NTT(r, limit, -1);
		// New high half: w[i] = 2*0 - (w*e)[i] for len/2 <= i < len.
		for (int i = (len >> 1); i < len; i++)
			w[i] = jian(cheng(w[i], 2), r[i]);
	}
	cpy(f, w, n);
	// The rounds above touched up to `limit` (>= n) entries of each buffer.
	// Clearing only n of them left stale data for the next call.
	clr(w, limit); clr(r, limit); clr(tmp, limit);
}

// Differentiate in place. The n coefficients f[0..n) become those of f':
// f[i-1] = i * f[i], and the now-unused top slot f[n-1] is set to 0.
void dao(int *f, int n) {
	// With n <= 0 there is nothing to differentiate, and the final store would
	// otherwise write f[-1].
	if (n <= 0) return;
	for (int i = 1; i < n; i++)
		f[i - 1] = cheng(f[i], i);
	f[n - 1] = 0;
}

// Integrate in place, treating n as the highest degree: writes f[0..n], with
// f[i] = f[i-1] / i for i = n..1 and a zero constant term.
// Walks downward so each f[i-1] is read before it is overwritten.
void jifen(int *f, int n) {
	int deg = n;
	while (deg >= 1) {
		f[deg] = cheng(f[deg - 1], inv[deg]);
		--deg;
	}
	f[0] = 0;
}

void lnp(int *f, int n) {
	static int g[N];
	cpy(g, f, n); dao(g, n);
	invp(f, n); times(f, g, n, n);
	jifen(f, n - 1); clr(g, n);
}

int main() {
	Init();

	// Input: n, then the n coefficients a_0 .. a_{n-1}.
	scanf("%d", &n);
	int i = 0;
	while (i < n) scanf("%d", &f[i++]);

	lnp(f, n);

	// Output: the n coefficients of ln A(x) mod x^n, each followed by a space.
	for (int k = 0; k < n; ++k) printf("%d ", f[k]);

	return 0;
}