1. 程式人生 > 實用技巧 >BZOJ-1009 [HNOI2008]GT考試(矩陣快速冪加速dp+KMP)

BZOJ-1009 [HNOI2008]GT考試(矩陣快速冪加速dp+KMP)

題目描述

  准考證號為 \(n\) 位數 \(X_1,X_2,\cdots,X_n(0\leq X_i\leq 9)\),不希望准考證號上出現不吉利的數字。不吉利數字 \(A_1,A_2,\cdots,A_m(0\leq A_i\leq 9)\)\(m\) 位,不出現是指 \(X_1,X_2,\cdots,X_n\) 中沒有恰好一段等於 \(A_1,A_2,\cdots,A_m\)\(A_1\)\(X_1\)可以為 \(0\)。求合法方案數,答案對 \(p\) 取模。

  資料範圍:\(n\leq 10^9,m\leq 20,p\leq 1000\)

分析

  設 \(dp[i][j]\) 表示准考證的前 \(i\)

位數中的最後 \(j\) 位與不吉利的數字的前 \(j\) 位相同時,前 \(i\) 位的合法方案數,且 \(dp[i][j]\) 的每種方案都不含長度大於 \(j\) 且與不吉利數字的字首相同的字尾(為了避免重複)。則答案為:\(dp[n][0]+dp[n][1]+\cdots+dp[n][m-1]\)

  考慮如何進行狀態轉移,\(dp[i][j]\) 只能由 \(dp[i-1][k]\) 轉移過來,相當於填完第 \(i-1\) 位後,長為 \(k\) 的字尾後面新新增一位 \(num\),此時這個有 \(i\) 位的數字與不吉利數字字首相同的最長的字尾的長度為 \(j\)

  狀態轉移方程為:

\[dp[i][j]=\sum_{k=0}^{m-1}dp[i-1][k]\times f[k][j] \]

  上式的 \(f[k][j]\) 表示當前准考證的長為 \(k\) 的字尾已經匹配了不吉利數字長為 \(k\) 的字首,有多少種新增一個數字 \(num\) 的方法,能使匹配長度變為 \(j\)

  舉個例子:假設不吉利數字是 \(123124\),則 \(dp[i][3]=dp[i-1][2]+dp[i-1][5]\),因為 \(dp[i-1][2]\) 的字尾\(\cdots12\) 不能是 \(\cdots12312\),所以還需要 \(dp[i-1][5]\) 來補充;假設不吉利數字是 \(123123\)

,則 \(dp[i][3]=dp[i-1][2]\),因為 \(dp[i][3]\) 末尾的 \(\cdots123\) 不能是 \(\cdots123123\)

  因為我們現在已經知道不吉利數字是什麼,所以 \(f[k][j]\) 矩陣是固定的,可以用 \(\text{KMP}\) 演算法預處理 \(\text{Next}\) 陣列,然後列舉長度 \(k\) 和新增的數字 $num $,沿著 \(\text{Next}\) 陣列往前跳找到來找到能轉移到的 \(j\),從而預處理出 \(f[k][j]\) 陣列。

  可以發現這個狀態轉移方程和矩陣乘法的的式子非常像,用矩陣快速冪加速 $dp $ 即可,時間複雜度 \(O(m^3\log n)\)

程式碼

#include<bits/stdc++.h>
using namespace std;
int n,m,mod,Next[30];
char s[30];
struct matrix
{
    int mat[30][30];
    matrix()
    {
        memset(mat, 0, sizeof(mat));
    }
}A;
matrix mul(matrix A,matrix B)
{
    matrix ans;
    for(int i=0;i<m;i++)
        for(int j=0;j<m;j++)
            for(int k=0;k<m;k++)
                ans.mat[i][j]=(ans.mat[i][j]+A.mat[i][k]*B.mat[k][j])%mod;
    return ans;
}
matrix matrix_pow(matrix a,int b)
{
    matrix ans;
    for(int i=0;i<m;i++)
        for(int j=0;j<m;j++)
            ans.mat[i][j]=(i==j);
    while(b)
    {
        if(b&1)
            ans=mul(ans,a);
        a=mul(a,a);
        b>>=1;
    }
    return ans;
}
void get_next()
{
    Next[1]=0;
    int j=0;
    for(int i=2;i<=m;i++)
    {
        while(j>0&&s[i]!=s[j+1])
            j=Next[j];
        if(s[i]==s[j+1])
            j++;
        Next[i]=j;
    }
}
matrix KMP()
{
    get_next();
    matrix ans;
    for(int i=0;i<=m-1;i++)
    {
        for(char num=0;num<=9;num++)
        {
            int j=i;
            while(j>0&&num!=(int)(s[j+1]-'0'))
                j=Next[j];
            if(num==(int)(s[j+1]-'0'))
                j++;
            ans.mat[i][j]++;
        }
    }
    return ans;
}
matrix f;
int main()
{
    cin>>n>>m>>mod;
    scanf("%s",s+1);
    f=matrix_pow(KMP(),n);
    int ans=0;
    for(int i=0;i<=m-1;i++)
        ans=(ans+f.mat[0][i])%mod;
    cout<<ans<<endl;
    return 0;
}