1. 程式人生 > 實用技巧 >【Luogu P5343】【XR-1】分塊

【Luogu P5343】【XR-1】分塊

題目連結:

題目

題目大意:

給定兩個陣列 \(A,B\),現在要給 \(n\) 分組,每塊的長度必須在兩個陣列的交集裡(即 \(\sum_{i=1}a_i=n\quad(a_i\in A\cap B)\)),求方案數。

正文:

很明顯的一個 DP,設 \(f_i\) 表示長度為 \(i\) 的分組的方案數,那麼這麼轉移:

\[f_i=\sum_{j=1}^{m}f_{i-a_j} \]

\(m\) 表示最長的可能的塊長。

時間複雜度:\(O(nm)\)\(n\leq10^{18}\),這玩意過不去。

考慮用矩陣乘法加速遞推。

按照矩乘加速遞推的尿性,我們可以用一個 \(m\times 1\) 的矩陣來表示狀態:

\[\begin{bmatrix} f_{m}\\ f_{m-1}\\ \vdots\\ f_1 \end{bmatrix}\]

要想辦法轉移到 \(\begin{bmatrix} f_{m+1}\\ f_{m}\\ \vdots\\ f_3\\ f_2 \end{bmatrix}\),必須再構造一個 \(m\times m\) 的轉移矩陣。

如果要轉移到 \(f_{m+1}\),根據轉移方程 \(f_i=\sum_{j=1}^{m}f_{i-a_j}\) 和矩陣乘法的運算可以推斷轉移矩陣的第一行(轉移到 \(f_{m+1}\))肯定和 \(A\cap B\) 相關。

拿樣例為例。

樣例中 \(A\cap B = \{1,2\}\),轉移矩陣第一行相應的位置就是 \(1\),比如樣例的轉移矩陣的第一行就是 \(\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}\)。因為這樣,\(f_{m+1}\) 就能被轉移得到:

\[\begin{bmatrix} f_{m}\\ f_{m-1}\\ \vdots\\ f_1 \end{bmatrix} \times\begin{bmatrix}1&1&0&0&\cdots&0\end{bmatrix}=\begin{bmatrix}f_{m+1-1}\times 1+f_{m+1-2}\times 1+f_{m+1-2}\times0+\cdots+f_{1+1-1}\times0\end{bmatrix}=\begin{bmatrix}f_{m+1}\end{bmatrix}\]

接下來 \(m-1\) 行就更好辦了,由於目標矩陣第 \(2\) 到第 \(m\) 行就是初始矩陣的第 \(1\) 到第 \(m-1\) 行,那轉移矩陣就是:

\[\begin{bmatrix}1&1&\cdots&0&0\\ 1&0&\cdots&0&0\\ 0&1&\cdots&0&0\\ \vdots&\vdots&\ddots&\vdots&\vdots\\ 0&0&\cdots&1&0\end{bmatrix}\]

再套個矩陣快速冪就A了。

程式碼:


int f[N];
ll n;
bool PR[N], NF[N];

struct matrix
{
	ll mat[N][N];
	int n, m;
	matrix(){memset(mat, 0, sizeof mat);}
	inline ll* operator [] (int b) { return mat[b];}
}F, stp;

inline matrix operator*(matrix &a, matrix &b)
{
	matrix c; c.n = a.n, c.m = b.m;
	for (int i = 1; i <= a.n; i++)
		for (int j = 1; j <= b.m; j++)
			for (int k = 1; k <= a.m; k++)
				c[i][j] = (c[i][j] + (a[i][k] * b[k][j]) % mod) % mod;
	return c;
}

matrix qpow(matrix stp, ll b)
{
	matrix ans; ans.n = ans.m = stp.n;
	for (int i = 1; i <= ans.n; i++)
			ans[i][i] = 1;
	for (; b; b >>= 1)
	{
		if(b & 1) ans = ans * stp; 
		stp = stp * stp;
	}
	return ans;
}

int m;

int main()
{
	scanf ("%lld", &n);
	int pr;scanf ("%d", &pr);
	for (int i = 1, x; i <= pr; i++)
		scanf ("%d", &x), PR[x] = 1;
	int nf;scanf ("%d", &nf);
	for (int i = 1, x; i <= nf; i++)
		scanf ("%d", &x), NF[x] = (1 & PR[x]), m = (NF[x] == 1? max(x, m): m);
	stp.n = stp.m = F.n = m, F.m = 1;


	f[0] = 1;
	for (int i = 1; i < m; i++)
	{
		for (int j = 1; j <= i; j++)
			if(NF[j]) f[i] = (f[i] + f[i - j]) % mod;
		F[m - i][1] = f[i];
	}
	F[m][1] = 1;


	for (int i = 1; i <= m; i++)
		stp[1][i] = NF[i], 
		stp[i + 1][i] = 1;
	stp[m + 1][m] = 0;


	F = qpow(stp, n - m + 1ll) * F;
	printf("%lld", F[1][1]);
	return 0;
}