1. 程式人生 > 其它 >【luogu P4717】【模板】快速莫比烏斯/沃爾什變換 (FMT/FWT)(數學)

【luogu P4717】【模板】快速莫比烏斯/沃爾什變換 (FMT/FWT)(數學)

【模板】快速莫比烏斯/沃爾什變換 (FMT/FWT)

題目連結:luogu P4717

題目大意

給你兩個長度為 2^n 的陣列 A,B,設陣列 C:
C[i]=sum{j⊕k=i}A[j]B[k]
分別當 ⊕ 是 or,and,xor 三種運算子時求出陣列 C。

思路

嗯這裡只會講簡單的 FWT,而且不說 FMT。
(畢竟聽說 FMT 被 FWT 完爆所以就沒學awa)
如果要看複雜的 FWT 可以看這位大佬的部落格

其實感覺這個 FWT 的過程其實就是類比了 FFT 這一類的方式。

\(c_i=\sum\limits_{i=j⊕ k}a_jb_k\)

那我們先看就題目的三種運算。

\(c_i=\sum\limits_{i=j|k}a_jb_k\)


考慮搞點性質,發現 \(j|i=i,k|i=i\rightarrow(k|j)|i=i\)

然後我們考慮構造 \(fwt[a/b/c]_i\),使得 \(fwt[a]_i*fwt[b]_i=fwt[c]_i\)
然後可以找到 \(fwt[a]_i=\sum\limits_{j|i=i}a_j\)
\(fwt[a]_i*fwt[b]_i=(\sum\limits_{j|i=i}a_j)(\sum\limits_{k|i=i}b_k)\)
\(=\sum\limits_{j|i=i}\sum\limits_{k|i=i}a_jb_k=\sum\limits_{(j|k)|i=i}a_jb_k=fwt[c]_i\)

然後接著就是考慮 \(fwt[a]_i\) 這些怎麼快速求。
考慮先看下標二進位制的最高位,然後讓 \(a0\) 表示它下標最高位為 \(0\) 的那部分序列,\(a1\) 表示為 \(1\) 的那部分。
然後根據或的性質就有:
\(fwt[a]_i=(fwt[a0]_i,fwt[a0]+fwt[a1])\)
(就是這兩個部分拼在一起,加號是每個位置加起來)

然後至於從 \(fwt[a]_i\) 求會 \(a\) 可以根據上面的反過來得到:
\(a=(a0,a1-a0)\)

然後你會發現它的形式很像 FFT/NTT 裡面的,然後你就類似著搞就可以了。

跟著或的道理,你會發現是:
\(fwt[a]_i=(fwt[a0]_i+fwt[a1],fwt[a1])\)


\(a=(a0-a1,a1)\)

異或

這個就不一樣的,考慮再找性質:
如果我們讓 \(x⊕y=count1(x\&y)\bmod 2\)\(count(1)\) 是二進位制中 \(1\) 的個數)
然後你會發現有 \((i⊕j)xor(i⊕k)=i⊕(j\ xor\ k)\)(你分類討論一下會發現確實是這樣的)

然後你就構造,可以得到:
\(fwt[a]_i=\sum\limits_{i⊕j=0}a_j-\sum\limits_{i⊕j=1}a_j\)

然後:
\(fwt[a]=(fwt[a0]+fwt[a1],fwt[a0]-fwt[a1])\)
\(a=(\dfrac{a0+a1}{2},\dfrac{a0-a1}{2})\)

構造怎麼弄的?

相信你看了這三個的構造 \(fwt[a]_i\) 陣列的結果,會有那麼一絲絲的疑惑:為啥能想到這個構造方法。

那我們就簡單講講如何構造 FWT 中的這個陣列。
首先我們用未知數表示:
\(fwt[a]_i=\sum\limits_{j=0}^{n-1}s(i,j)a_j\)

然後列出要求的式子:
\(fwt[a]_i*fwt[b]_i=fwt[c]_i\)
\(\sum\limits_{j=0}^{n-1}s(i,j)a_j\sum\limits_{k=0}^{n-1}s(i,k)b_k=\sum\limits_{p=0}^{n-1}s(i,p)c_p\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{p=0}^{n-1}s(i,p)c_p\)
然後再有 \(a*b=c\)
\(c_p=\sum\limits_{j\oplus k=p}a_jb_k\)
\(\sum\limits_{p=0}^{n-1}s(i,p)c_p=\sum\limits_{p=0}^{n-1}s(i,p)\sum\limits_{j\oplus k=p}a_jb_k\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{p=0}^{n-1}s(i,p)\sum\limits_{j\oplus k=p}a_jb_k=\sum\limits_{p=0}^{n-1}\sum\limits_{j\oplus k=p}a_jb_ks(i,j\oplus k)=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}a_jb_ks(i,j\oplus k)\)
\(\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}s(i,j)s(i,k)a_jb_k=\sum\limits_{j=0}^{n-1}\sum\limits_{k=0}^{n-1}a_jb_ks(i,j\oplus k)\)

所以我們就需要讓 \(s(i,j)s(i,k)=s(i,j\oplus k)\)
接著就要用到 FWT 最特別的地方了:它是解決有關位運算的問題的。
也就是說它二進位制每一位是互相獨立的!

所以假設我們已經求出來對於一位的 \(s([0,1],[0,1])\),那我們就可以構造出所有的 \(s\)
\(a\) 二進位制的每一位是 \(a_0,a_1,a_2,...\),那 \(s(i,j)=s(i_0,j_0)s(i_1,j_1)s(i_2,j_2)...\),就是每位的乘起來。
那麼對於每一位:
\(s(i,j)s(i,k)=s(i,j\oplus k)\Leftrightarrow s(i_t,j_t)s(i_t,k_t)=s(i_t,j_t\oplus k_t)\)

那這個我們就可以每一位通過 \(0,1\) 的分類討論求解,得到符合的 \(s\)
而且在 or,and,xor 中它們的符合的 \(s\) 其實是有兩種的,那我們選隨便一個都可以用了。
比如 or 的是有這兩種:
\(\begin{bmatrix}1&1\\1&0\end{bmatrix}\)\(\begin{bmatrix}1 & 0\\ 1 & 1\end{bmatrix}\)

那構造就構造好啦!

\(fwt[a]\) 的求法感覺還是有點迷

那我們繼續用上面的來:
\(fwt[a]_i=\sum\limits_{j=0}^{n-1}s(i,j)a_j\)

繼續折半:\(\sum\limits_{j=0}^{n/2-1}s(i,j)a_j+\sum\limits_{j=n/2}^{n-1}s(i,j)a_j\)
然後也想前面那樣 \(i'\)\(i\) 去掉二進位制首位的數。
\(\sum\limits_{j=0}^{n/2-1}s(i_0,j_0)s(i',j')a_j+\sum\limits_{j=n/2}^{n-1}s(i_0,j_0)s(i',j')a_j\)
\(\sum\limits_{j=0}^{n/2-1}s(i_0,0)s(i',j')a_j+\sum\limits_{j=n/2}^{n-1}s(i_0,1)s(i',j')a_j\)
\(c(i',j')\) 就是去掉首位的,規模自然減半。
\(i_0=0\)\(0\leqslant i<n/2\)\(fwt[a]_i=s(0,0)fwt(a_0)_i+s(0,1)fwt(a_1)_i\)
\(i_0=1\)\(n/2\leqslant <n\)\(fwt[a]_i=s(1,0)fwt(a_0)_i+s(1,1)fwt(a_1)_i\)

然後如果是 \(ifwt\)(就是從 \(fwt[a]\)\(a\))就是把 \(s\) 這個矩陣求逆。

擴充套件一下,如果不是位運算還有可能用的上嗎

其實是可以的,因為位運算你可以看做是 \(n\)\(01\) 向量做運算:
or 是取 \(\max\),and 是取 \(\min\),xor 是每一維相加的結果 \(\bmod\ 2\)

那我們可以擴充套件到 \([0,k)\),那我們要的 \(s\) 就是一個 \(k*k\) 的矩陣,然後暴力算是 \(k^{n+1}n\)
然後這些矩陣也可以快速算,\(\max\min\) 是高位字首和壓掉一個 \(k\)\(\mod k\) 的話列 \(s\) 可以用範德蒙德矩陣。
然後可以用 FTT 把一個 \(k\) 變成 \(\log k\)

r然而又因為單位根模的意義下可能不存在,所以你要通過再來一個 \(k\) 的複雜度以及一通神仙操作(別想了我不可能會的去看那位大佬的部落格吧)

程式碼

#include<cstdio>
#include<cstring>
#define mo 998244353
#define cpy(f, g, n) memcpy(f, g, sizeof(int) * (n))
#define clr(f, n) memset(f, 0, sizeof(int) * (n))

using namespace std;

const int N = (1 << 17);
int n, f[N], g[N], inv2, tmp[N];

int jia(int x, int y) {return x + y >= mo ? x + y - mo : x + y;}
int jian(int x, int y) {return x - y < 0 ? x - y + mo : x - y;}
int cheng(int x, int y) {return 1ll * x * y % mo;}

void px(int *f, int *g, int n) {
	for (int i = 0; i < n; i++)
		f[i] = cheng(f[i], g[i]);
}

void FWT_or(int *f, int n, int op) {
	for (int mid = 1; mid < n; mid <<= 1)
		for (int j = 0; j < n; j += (mid << 1))
			for (int k = 0; k < mid; k++) {
				int x = f[j | k], y = f[j | mid | k];
				f[j | k] = x; f[j | mid | k] = jia(cheng((op == 1) ? 1 : mo - 1, x), y);
			}
}

void FWT_and(int *f, int n, int op) {
	for (int mid = 1; mid < n; mid <<= 1)
		for (int j = 0; j < n; j += (mid << 1))
			for (int k = 0; k < mid; k++) {
				int x = f[j | k], y = f[j | mid | k];
				f[j | k] = jia(x, cheng((op == 1) ? 1 : mo - 1, y)); f[j | mid | k] = y;
			}
}

void FWT_xor(int *f, int n, int op) {
	for (int mid = 1; mid < n; mid <<= 1)
		for (int j = 0; j < n; j += (mid << 1))
			for (int k = 0; k < mid; k++) {
				int x = f[j | k], y = f[j | mid | k];
				f[j | k] = jia(x, y); f[j | mid | k] = jian(x, y);
				if (op == -1) f[j | k] = cheng(f[j | k], inv2), f[j | mid | k] = cheng(f[j | mid | k], inv2);
			}
}

void cheng_or(int *f, int *g, int n) {
	static int Tmp[N];
	cpy(Tmp, g, n);
	FWT_or(f, n, 1); FWT_or(Tmp, n, 1);
	px(f, Tmp, n); FWT_or(f, n, -1);
	clr(Tmp, n);
}

void cheng_and(int *f, int *g, int n) {
	static int tmp[N];
	cpy(tmp, g, n);
	FWT_and(f, n, 1); FWT_and(tmp, n, 1);
	px(f, tmp, n); FWT_and(f, n, -1);
	clr(tmp, n);
}

void cheng_xor(int *f, int *g, int n) {
	static int tmp[N];
	cpy(tmp, g, n);
	FWT_xor(f, n, 1); FWT_xor(tmp, n, 1);
	px(f, tmp, n); FWT_xor(f, n, -1);
	clr(tmp, n);
}

int main() {
	scanf("%d", &n); n = (1 << n);
	for (int i = 0; i < n; i++) scanf("%d", &f[i]);
	for (int i = 0; i < n; i++) scanf("%d", &g[i]);
	
	inv2 = (mo + 1) / 2;
	cpy(tmp, f, n); cheng_or(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
	cpy(tmp, f, n); cheng_and(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
	cpy(tmp, f, n); cheng_xor(tmp, g, n); for (int i = 0; i < n; i++) printf("%d ", tmp[i]); puts("");
	
	return 0;
}