1. 程式人生 > >【學習筆記】Berlekamp-Massey演算法

【學習筆記】Berlekamp-Massey演算法

【演算法簡介】

  • Berlekamp-Massey演算法,常簡稱為BM演算法,是用來求解一個數列的最短線性遞推式的演算法。
  • BM演算法可以在\(O(N^2)\)的時間內求解一個長度為\(N\)的數列的最短線性遞推式。
  • 在當今OI競賽界,尚沒有很多BM演算法的應用,但在一些輸入的數很少的題目中,BM能夠成為發掘題目性質的一大助力,甚至有可能直接解出答案的線性遞推式,不失為一種不錯的工具。

【演算法流程】

  • 對於數列\(\{a_1,a_2,a_3,...,a_n\}\),我們稱數列\(\{r_1,r_2,r_3,...,r_m\}\)為其線性遞推式當且僅當\(a_i=\sum_{j=1}^{m}r_j*a_{i-j}\)對於任意\(m+1≤i≤n\)均成立。
  • 若數列\(\{a_1,a_2,a_3,...,a_n\}\)的線性遞推式\(\{r_1,r_2,r_3,...,r_m\}\)還滿足\(m\)是所有該數列的線性遞推式中最小的,則稱\(\{r_1,r_2,r_3,...,r_m\}\)為數列\(\{a_1,a_2,a_3,...,a_n\}\)的最短線性遞推式。
  • 現在考慮我們已經求得了\(\{a_1,a_2,a_3,...,a_{i-1}\}\)的最短線性遞推式\(\{r_1,r_2,r_3,...,r_m\}\),如何求得\(\{a_1,a_2,a_3,...,a_i\}\)的最短線性遞推式。
  • 定義\(\{a_1,a_2,a_3,...,a_{i-1}\}\)的
    最短線性遞推式\(\{r_1,r_2,r_3,...,r_m\}\)為當前遞推式,記遞推式被更改的次數為\(cnt\),第\(i\)次更改後的遞推式為\(R_i\),特別地,定義\(R_0\)為空,那麼當前遞推式應當為\(R_{cnt}\)。
  • 記\(delta_i=a_i-\sum_{j=1}^{m}r_j*a_{i-j}\),其中\(\{r_1,r_2,r_3,...,r_m\}\)為當前遞推式,顯然若\(delta_i=0\),那麼當前遞推式就是\(\{a_1,a_2,a_3,...,a_i\}\)的最短線性遞推式。
  • 否則,我們認為\(R_{cnt}\)在\(a_i\)處出錯了,定義\(fail_i\)為\(R_i\)最早的出錯位置,則有\(fail_{cnt}=i\)。
  • 考慮對\(R_{cnt}\)進行修改,使其變為\(R_{cnt+1}\),並在\(a_i\)處同樣成立。
  • 若\(cnt=0\),這意味著\(a_i\)是序列中第一個非零元素,我們可以令\(R_{cnt+1}=\{0,0,0,...,0\}\),即用\(i\)個0填充線性遞推式,此時由於不存在\(j\)使得\(m+1≤j≤i\),因此\(R_{cnt+1}\)顯然為\(\{a_1,a_2,a_3,...,a_i\}\)的線性遞推式,並且由於\(a_i\)是序列中第一個非零元素,不難證明\(R_{cnt+1}\)也是\(\{a_1,a_2,a_3,...,a_i\}\)的最短線性遞推式。
  • 否則,即\(cnt>0\),考慮\(R_{cnt-1}\)出錯的位置\(fail_{cnt-1}\),記\(mul=\frac{delta_i}{delta_{fail_{cnt-1}}}\)。
  • 我們希望得到數列\(R'=\{r'_1,r'_2,r'_3,...,r'_{m'}\}\),使得\(\sum_{j=1}^{m'}r'_j*a_{k-j}=0\)對於任意\(m'+1≤k≤i-1\)均成立,並且\(\sum_{j=1}^{m'}r'_j*a_{i-j}=delta_i\)。如果能夠找到這樣的數列\(R'\),那麼令\(R_{cnt+1}=R_{cnt}+R'\)即可(其中\(+\)定義為各位分別相加)。
  • 數列\(R'\)可以是下述數列:\(\{0,0,0,...,0,mul,-mul*R_{cnt-1}\}\)即填充\(i-fail_{cnt-1}-1\)個零,然後將數列\(\{1,-R_{cnt-1}\}\)的\(mul\)倍放在後面。此時有\(\sum_{j=1}^{m'}r'_j*a_{i-j}=delta_{fail_{cnt-1}}*mul=delta_i\),並且\(\sum_{j=1}^{m'}r'_j*a_{k-j}=0\)對於任意\(m'+1≤k≤i-1\)均成立
  • 故令\(R_{cnt+1}=R_{cnt}+R'\)即可。
  • 在最壞情況下,我們可能需要對數列進行\(O(N)\)次修改,因此該演算法的時間複雜度為\(O(N^2)\)。

【一組例項】

  • 以數列\(\{1,2,4,9,20,40,90\}\)為例,我們來具體地理解一下演算法流程。
  • 初始時,我們有\(R_0=\{\},cnt=0\)。
  • \(i=1\)時,將\(a_1=1\)代入遞推式,得到\(delta_1=1\),\(R_0\)在\(i=1\)時出錯,記\(fail_0=1\)。由於此時\(cnt=0\),我們將遞推式修改為\(R_1=\{0\}\)。
  • \(i=2\)時,將\(a_2=2\)代入遞推式,得到\(delta_2=2\),\(R_1\)在\(i=2\)時出錯,記\(fail_1=2\)。此時\(mul=2\),可以構造得到\(R'=\{2\}\),遞推式被修改為\(R_2=\{2\}\)。
  • \(i=3\)時,將\(a_3=4\)代入遞推式,得到\(delta_3=0\),\(R_2\)沒有出錯。
  • \(i=4\)時,將\(a_4=9\)代入遞推式,得到\(delta_4=1\),\(R_2\)在\(i=4\)時出錯,記\(fail_2=4\)。此時\(mul=0.5\),可以構造得到\(R'=\{0,0.5,0\}\),遞推式被修改為\(R_3=\{2,0.5,0\}\)。
  • \(i=5\)時,將\(a_5=20\)代入遞推式,得到\(delta_5=0\),\(R_3\)沒有出錯。
  • \(i=6\)時,將\(a_6=40\)代入遞推式,得到\(delta_6=-4.5\),\(R_3\)在\(i=6\)時出錯,記\(fail_3=6\)。此時\(mul=-4.5\),可以構造得到\(R'=\{0,-4.5,9\}\),遞推式被修改為\(R_4=\{2,-4,9\}\)。
  • \(i=7\)時,將\(a_7=90\)代入遞推式,得到\(delta_7=9\),\(R_4\)在\(i=7\)時出錯,記\(fail_4=7\)。此時\(mul=-2\),可以構造得到\(R'=\{-2,4,1,0\}\),遞推式被修改為\(R_5=\{0,0,10,0\}\)。
  • 因此以數列\(\{1,2,4,9,20,40,90\}\)的遞推式即為\(R_5=\{0,0,10,0\}\)。

【程式碼】

  • 以下程式碼實現了求解給定\(N\)元數列在實數域上的最短線性遞推式。
  • 顯然,BM演算法只需要數域中每個非零元素均存在乘法逆元即可實現,讀者不妨自行實現一下在模質數意義下的BM演算法。
#include<bits/stdc++.h>
using namespace std;
const int MAXN = 2005;
const double eps = 1e-8;
template <typename T> void read(T &x) {
	x = 0; int f = 1;
	char c = getchar();
	for (; !isdigit(c); c = getchar()) if (c == '-') f = -f;
	for (; isdigit(c); c = getchar()) x = x * 10 + c - '0';
	x *= f;
}
int cnt, fail[MAXN];
double val[MAXN], delta[MAXN];
vector <double> ans[MAXN];
int main() {
	int n; read(n);
	for (int i = 1; i <= n; i++)
		scanf("%lf", &val[i]);
	for (int i = 1; i <= n; i++) {
		double tmp = val[i];
		for (unsigned j = 0; j < ans[cnt].size(); j++)
			tmp -= ans[cnt][j] * val[i - j - 1];
		delta[i] = tmp;
		if (fabs(tmp) <= eps) continue;
		fail[cnt] = i;
		if (cnt == 0) {
			ans[++cnt].resize(i);
			continue;
		}
		double mul = delta[i] / delta[fail[cnt - 1]];
		cnt++; ans[cnt].resize(i - fail[cnt - 2] - 1);
		ans[cnt].push_back(mul);
		for (unsigned j = 0; j < ans[cnt - 2].size(); j++)
			ans[cnt].push_back(ans[cnt - 2][j] * -mul);
		if (ans[cnt].size() < ans[cnt - 1].size()) ans[cnt].resize(ans[cnt - 1].size());
		for (unsigned j = 0; j < ans[cnt - 1].size(); j++)
			ans[cnt][j] += ans[cnt - 1][j];
	}
	for (unsigned i = 0; i < ans[cnt].size(); i++)
		cout << ans[cnt][i] << ' ';
	return 0;
}