1. 程式人生 > 其它 >快速沃爾什變換學習筆記

快速沃爾什變換學習筆記

說好了在noip之前不學多項式演算法……結果就真香了

快速沃爾什變換

給定長度為 \(2^n\) 兩個序列 \(A,B\),設

\[C_g=\sum_{i\bigoplus j=g}A_i \times B_j \]

分別當 \(\bigoplus\)\(or,and,xor\) 時求出 \(C\)

\(n\leq 17\)


據說與 \(FFT\) 的核心思想相同,都是對陣列的變換。

對於原陣列A,B,在某種運算規則下,它們的結果很難求。但若是對該陣列進行變換得到新陣列 \(A'\)\(B'\),而他們在該運算規則下的計算很好求,這時就能得到 \(C'\)。最後再對該序列進行逆變換即可獲得答案 \(C\)


記該變換陣列為 \(FWT\),逆變換陣列為 \(IFWT\)

or 運算下

定義 \(FWT[A] = \sum_{i|j=i}A_j\)

根據定義可推得

\(FWT[C] =FWT[A] \times FWT[B]\)

證明:

\[FWT[A] \times FWT[B] = \left(\sum_{j|i=i}A_j\right)\left(\sum_{k|i=i}B_k\right) \]\[=\left(\sum_{j|k|i=i}A_jB_k\right) \]\[=\left(\sum_{g|i=i}\sum_{j|k=g}A_jB_k\right) \]

拆開 \(C\) 的定義式後,與上式形式相同。

證畢。

有了這個性質後我們接下來還需要解決兩個問題

1:已知 A 如何快速求 FWT[A]
2:已知 FWT[A] 如何逆向求 A

記 A_0 為 A 下標中最高位為 0 的部分,A_1 為 A 下標中最高位為 1 的部分。

\((G,K)\)表示將這兩個序列前後接起來。

\(A + B\) 為$$\left{ A_1+B_1 ,A_2+B_2,A_3+B_3\dots A_n+B_n\right}$$

\(A \cdot B\)

\[\left\{ A_1\times B_1 ,A_2\times B_2,A_3\times B_3\dots A_n\times B_n\right\} \]


\(2\leq|A|\)

\[FWT[A]=(FWT[A_0],FWT[A_0]+FWT[A_1]) \]

\(n=1\)時$$FWT[A]=A$$

\(2\leq|A|\)

\[IFWT[A]=(IFWT[A_0],IFWT[A_1]-IFWT[A_0]) \]

\(n=1\)時$$IFWT[A]=A$$

按子集規規模理解即可,和下邊的xor的數學歸納法證明相似

and 運算下

與運算同理

\(FWT[A]=\begin{cases}(FWT[A_0]+FWT[A_1],FWT[A_1])&2\leq n\\ A& n=1\end{cases}\)

\(IFWT[A]=\begin{cases}(IFWT[A_0]-IFWT[A_1],IFWT[A_1])&2\leq n\\ A& n=1\end{cases}\)

xor 運算下

突然就難了一個級別

定義FWT[A]如下定義

\(FWT[A]=\begin{cases}(FWT[A_0]+FWT[A_1],FWT[A_0]-FWT[A_1])&2\leq n\\ A& n=1\end{cases}\)

性質1:

\(FWT(A+B)=FWT(A)+FWT(B)\)

根據FWT[A]中每一維都是A中元素的線性組合可知

性質2:

\(FWT(A\bigoplus B)=FWT(A) \cdot FWT(B)\)

證明:
應用數學歸納法
\(n=1\)顯然成立。

\(FWT(A⊕ B)=FWT((A⊕ B)_0,(A⊕ B)_1)\)

\(=FWT(A0⊕B0+A1⊕B1,A0⊕B1+A1⊕B0)\)

\(=FWT(A0⊕B0+A1⊕B1+A0⊕B1+A1⊕B0,A0⊕B0+A1⊕B1-A0⊕B1-A1⊕B0)\)

\(=(FWT(A0⊕B0)+FWT(A1⊕B1)+FWT(A0⊕B1)+FWT(A1⊕B0),FWT(A0⊕B0)+FWT(A1⊕B1)-FWT(A0⊕B1)-FWT(A1⊕B0))\)

\(=(FWT(A0)\cdot FWT(B0)+FWT(A1)\cdot FWT(B1)+FWT(A0)\cdot FWT(B1)+FWT(A1)\cdot FWT(B0),FWT(A0)\cdot FWT(B0)+FWT(A1)\cdot FWT(B1)-FWT(A0)\cdot FWT(B1)-FWT(A1)\cdot FWT(B0))\)

\(=(FWT(A0+A1)\cdot FWT(B0+B1),FWT(A0-A1)\cdot FWT(B0-B1))\)
(將這個式子做點乘得到上面那步
\(=(FWT(A0+A1)+FWT(A0-A1))\cdot (FWT(B0-B1)+FWT(B0-B1))\)

\(=FWT(A) \cdot FWT(B)\)

\(\bigoplus\)拆解成點乘,就相當於數學歸納呼叫子問題

(證了這麼多終於得到正向變換le

考慮逆向變換

\(IFWT[A]=\begin{cases}(\frac{1}{2}\times (IFWT[A_0]+IFWT[A_1]),\frac{1}{2}\times (IFWT[A_0]-IFWT[A_1]))&2\leq n\\ A& n=1\end{cases}\)

證明:

\(IFWT(FWT(A))=IFWT((FWT(A0+A1),FWT(A0-A1))\)

\(=(IFWT(FWT(A0)),IFWT(FWT(A1)))\)

\(=(A0,A1)\)

\(=A\)
證畢

粘一下板子

P4717 【模板】快速莫比烏斯/沃爾什變換 (FMT/FWT)

#include<bits/stdc++.h>

using namespace std;

#define int long long
#define INF 1ll<<30
#define Int unsigned long long 

template<typename _T>
inline void read(_T &x)
{
	x=0;char s=getchar();int f=1;
	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
	x*=f;
}

#define lowbit(x) (x&(-x))
#define gb(x) ((x-1)/T + 1)
#define gl(x) ((x-1)*T + 1)
#define pb push_back 
#define mod 998244353

const int np = (1<<17) + 5;

int A[np],B[np];
int A_[np],B_[np];
int c[np];
int n_,n;

inline int power(int a,int b)
{
	int res = 1;
	while(b)
	{
		if(b & 1) res = res * a % mod;
		a = a * a;
		a %= mod;
		b>>=1;
	}
	return res;
}

inline void mul(){for(int i=0;i<n;i++)c[i] = A_[i] * B_[i] % mod;}

inline void FWTor(int *g,int opt)
{
	for(int i=1;i < n;i <<= 1)
		for(int o = 2 * i , j =0 ;j < n;j += o)
			for(int k=0; k < i; k++)
			{
				g[j + k + i] += g[j + k] * opt;
				g[j + k + i] = (g[j + k + i] + mod)%mod ;				
			}	
}

inline void FWTand(int *g,int opt)
{
	for(int i=1;i<n;i<<=1)
		for(int o = 2 * i , j =0 ;j<n;j += o)
			for(int k = 0;k<i;k++)
			{
				g[j + k] += g[j + k + i] * opt;
				g[j + k] = (g[j +k] + mod) % mod;
			}
}

inline void FWTxor(int *g,int opt)
{
	for(int i=1;i<n;i<<=1)
		for(int o=2 * i , j = 0;j<n;j += o)
			for(int k=0;k<i;k++)
			{
				int x = g[j + k],y = g[j + k + i];
				g[j + k] = (x + y) % mod;
				g[j + k + i] = (x - y + mod) %mod;
				if(opt != 1)
				{
					(g[j + k] *= opt)%=mod;
					(g[j + k + i] *= opt)%=mod;
				}
			}
}

inline void Init()
{
	for(int i=0;i<n;i++) A_[i] = A[i] , B_[i] = B[i];
}

signed main()
{
	read(n_);
	n = 1<<n_;
	for(int i=0;i<n;i++) read(A[i]);
	for(int i=0;i<n;i++) read(B[i]);
	
	int inv = power(2,mod-2);
	for(int i=0;i<n;i++) A_[i] = A[i] , B_[i] = B[i];
	FWTor(A_,1);
	FWTor(B_,1);
	mul();
	FWTor(c,-1);
	for(int i=0;i<n;i++) printf("%lld ",c[i]);
	printf("\n");
	for(int i=0;i<n;i++) A_[i] = A[i] ,B_[i] = B[i];
	FWTand(A_,1);
	FWTand(B_,1);
	mul();
	FWTand(c,-1);
	for(int i=0;i<n;i++) printf("%lld ",c[i]);
	printf("\n");
	Init();
	FWTxor(A_,1);
	FWTxor(B_,1);
	mul();
	FWTxor(c,inv);
	for(int i=0;i<n;i++) printf("%lld ",c[i]);
	
	
//	FWTxor();
 }

CF449D Jzzhu and Numbers

寫了一個 \(O(n \times 3^n)\) 暴力 dp

顯然是過不了的,接下來有兩種解決方案:

1:降維容斥

2:FWT科技解決問題

\(f_i\)為最後與出來的結果至少是\(i\)\(g_i\)為最後與出來的結果恰好是\(i\)

那麼有

\[g_i=f_i-\sum_{j\&i=i}g_j \]

\[g_0=f_0-\sum_{i=1}^{2^n-1}g_i \]

\(g\) 繼續展開有

\[g_0=\sum_{i=0}^{2^n-1}f_i\times (-1)^{|i|} \]

現在我們考慮如何求 \(f\)

\(f_x = 2^s-1\),其中 \(s\)\(i\&x=x\)的數的個數

\[\sum_{i\&x=x}1 \]

這個東西好像可以上沃爾什變換,然後 \(FWT\) 即可

程式碼迴圈展開了一下(為了卡最優解

#include<bits/stdc++.h>

using namespace std;

#define int long long
#define INF 1ll<<30
#define Int unsigned long long 

template<typename _T>
inline void read(_T &x)
{
	x=0;char s=getchar();int f=1;
	while(s<'0'||s>'9') {f=1;if(s=='-')f=-1;s=getchar();}
	while('0'<=s&&s<='9'){x=(x<<3)+(x<<1)+s-'0';s=getchar();}
	x*=f;
}

#define lowbit(x) (x&(-x))
#define gb(x) ((x-1)/T + 1)
#define gl(x) ((x-1)*T + 1)
#define pb push_back 
#define Re register
#define MOD(x) (x = (x + mod)%mod)
const int mod = 1e9 + 7;

const int np = 3e6 + 5;

int dp[np],dp_[np];
int A_[np];
int G[np];
int bac[np] , bit[np];
int A[np],n_;
int a[np];

int power(int a,int b)
{
	Re int res = 1;
	while(b)
	{
		if(b & 1) res = a * res % mod;
		a = a * a;
		a %= mod;
		b>>=1;
	}
	return res;
}

inline void FWT(int *g)
{
	for(int i=1;i<n_;i<<=1)
		for(register int o = 2*i,j=0;j<n_;j+=o)
			for(register int k=0;k<i;k++)
			{
				int &d = g[j + k];
				d += g[j + k + i];
				d >= mod?d -mod:0;
			}
	
}

signed main()
{
	
	int n,maxn = 0;
	read(n);
	n_ = 1;
	for(int i=1;i<=n;i++) read(a[i]),bac[a[i]]++, maxn = max(maxn , a[i]);
	
	while(n_ <= maxn) n_ <<= 1;

	FWT(bac);
	
	bit[0] = 0;
	for(int i=1;i<n_;i++)
	{
		bit[i] = bit[i - lowbit(i)] + 1;
	}
	
	for(int i=0;i<n_;i++)
	{
		bit[i] = (bit[i]&1)?-1:1;
	}
	
	Re int i(-1),Ans1(0),Ans2(0),Ans3(0),Ans4(0),Ans5(0),Ans6(0),Ans7(0),Ans8(0);
	Re int f1(0),f2(0),f3(0),f4(0),f5(0),f6(0),f7(0),f8(0),Ans(0);
	for(;i + 8<n_;i+=8)
	{
		f1 = power(2,bac[i + 1]) - 1;
		f2 = power(2,bac[i + 2]) - 1;
		f3 = power(2,bac[i + 3]) - 1;
		f4 = power(2,bac[i + 4]) - 1;
		f5 = power(2,bac[i + 5]) - 1;
		f6 = power(2,bac[i + 6]) - 1;
		f7 = power(2,bac[i + 7]) - 1;
		f8 = power(2,bac[i + 8]) - 1;
		Ans1 += f1 * bit[i + 1];
		MOD(Ans1);
		Ans2 += f2 * bit[i + 2];
		Ans3 += f3 * bit[i + 3];
		Ans4 += f4 * bit[i + 4];
		Ans5 += f5 * bit[i + 5];
		Ans6 += f6 * bit[i + 6];
		Ans7 += f7 * bit[i + 7];
		Ans8 += f8 * bit[i + 8];
		Ans += Ans1 + Ans2 + Ans3 + Ans4 + Ans5 + Ans6 + Ans7 + Ans8;
		Ans1 =Ans2 = Ans3 = Ans4 = Ans5 =Ans6 = Ans7 = Ans8 = 0;
		MOD(Ans);
	}
	i++;
	for(Re int f(0);i<n_;i++)
	{
		f = power(2,bac[i])-1;
		f *= bit[i];
		Ans += f;
		if(f < 0) Ans = (Ans + mod) %mod;
		else Ans %= mod;
	}
	Ans += Ans1 + Ans2 + Ans3 + Ans4 + Ans5 + Ans6 + Ans7 + Ans8;
	cout<<Ans;
	
 }

\(End\)