1. 程式人生 > 實用技巧 >Luogu P3990 [SHOI2013]超級跳馬

Luogu P3990 [SHOI2013]超級跳馬

矩陣優化DP

原題連結

矩陣乘法幾乎不會,我瞎扯證明了1個小時才差不多搞明白怎麼弄。

  1. \(DP\)

    \(\large dp_{i,j}\)表示跳到第\(i\)行第\(j\)列的方案數。

    於是有了一個初步的狀態轉移方程(假設\(i,j\)都不會越界):

    \(\large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i-1,j-3}+dp_{i-1,j-5}+...\ ...\)

    很明顯,它複雜度高,不方便求。

    然後手玩一下你就會發現後面那一大串(\(dp_{i-1,j-3}+...\ ...\) 以及往後的)跟\(dp_{i,j-2}\)

    相等。

    證明:

    首先,奇數+偶數(2)=奇數,小學知識

    因為馬每次只能橫向跳奇數格,所以能一次性達到\((i\ ,\ j-2)\)的點同時也能一次性達到\((i\ ,\ j)\),且沒有其他點(出去第\(j-2\)列上的三個點)能到達\((i\ ,\ j)\)

    然後狀態轉移方程就變成了

    \(\large dp_{i,j}=dp_{i-1,j-1}+dp_{i,j-1}+dp_{i+1,j-1}+dp_{i,j-2}\)

    複雜度是$ O(nm)$的,T飛。

  2. 矩陣優化:

    轉移矩陣的大小與\(n\)有關。

    如果\(n=3\),轉移矩陣是這樣噠:

    然後矩陣快速冪求一下就好了。

  3. 計算答案

    你以為怎麼著就結束了嗎?

    其實,還有一個細節問題沒有注意到:

    \(i=1,j=3\)時,迴歸轉移方程式:

    \(\large dp_{1,3}=dp_{1,2}+dp_{1,2}+dp_{1,1}\)

    然而這時候\(dp_{1,1}\)卻並不是由前面的點轉移過來的,而是特殊初值!!!

    也就是說對於\(dp_{i,j}\),答案都恰好會多算\(dp_{i,j-2}\)(這裡如果不明白可以自己手玩幾組小資料)。

    綜上所述,記錄答案時,應該求\(dp_{n,m}-dp_{n,m-2}\)

    當然也可以求\(dp_{n,m-1}+dp_{n-1,m-1}\)

多說一句:為何看著各位大佬的矩陣初值都是對角線為1啊,光把\(dp_{1,1}\)

初始化為1,計算答案時只用第一行不就行嗎?雖然沒多大區別

對了,答案別忘了%30011。

code(沒寫註釋,個人認為上面寫的比較詳細了):


#include<bits/stdc++.h> 
#define p 30011
using namespace std;
int read()
{
	int xsef = 0,yagx = 1;char cejt = getchar();
	while(cejt < '0'||cejt > '9'){if(cejt == '-')yagx = -1;cejt = getchar();}
	while(cejt >= '0'&&cejt <= '9'){xsef = (xsef << 1) + (xsef << 3) + cejt - '0';cejt = getchar();}
	return xsef * yagx;
}
int n,m;
struct node{
	int a[110][110];
	node(){
		memset(a, 0, sizeof(a));
	}
}b,c,d;
node operator * (node x, node y){
	node z;
	for(int i = 1;i <= n * 2;i++){
		for(int j = 1;j <= n * 2;j++){
			for(int k = 1;k <= n * 2;k++){
				z.a[i][j] = (z.a[i][j] + x.a[i][k] * y.a[k][j]) % p;
			}
		}
	}
	return z;
}
void build(){
	for(int i=1;i<=n;i++){
			if(i!=1)
				b.a[i][i-1]=1;
			b.a[i][i]=1;
			if(i!=n)
				b.a[i][i+1]=1;
		}
	for(int i=1;i<=n;i++){
		b.a[i][i+n]=1;
	}
	for(int i=1;i<=n;i++){
		b.a[i+n][i]=1;
	}
}
void ksm(){
	while(m){
		if(m&1)
			c=c*b;
		b=b*b;
		m>>=1;
	}
}
signed main(){
	n=read(),m=read();
	build();
	d = b;
	m-=2;
	for(int i=1;i<=n*2;i++)
		c.a[i][i]=1;
	ksm();
	int ans = -c.a[1][n+n];
	c=c*d;
	ans += c.a[1][n];
	printf("%d",ans + p % p);
	return 0;
}