1. 程式人生 > >zoj 3494 AC自動機+數位DP

zoj 3494 AC自動機+數位DP

這道題說實話我到現在也不能完整的把這份程式碼寫出來,大部分都是抄kuangbin部落格裡的那份程式碼,其實這份程式碼讓我讓我一個人看我差不多能看懂,但讓我寫我感覺我根本寫不了,首先我們對模式串建一個AC自動機,然後構造出來一些不可達的點,再寫成一個矩陣bcd[i][j]表示在AC自動機上第i個點加上j後能到達的點,j是十進位制,從0到9,然後再在bcd[i][j]上跑數位DP,我還回頭寫了兩道數位DP想了想怎麼寫,這裡面的數位DP注意是可以有前導0的,dp[len][i]表示長度為len在自動機上第J個節點的合法數量。

#include<bits/stdc++.h>
using namespace std;
using LL = int64_t;
const int maxnode=1e6+5;
const int sigma_size=4;
char s[maxnode];
const LL mod=1e9+9;
struct Node {
    int son[sigma_size];
    int val,fail;
}ch[maxnode];
int bcd[2010][10];
struct AC {
    int sz=1;
    queue<int>Q;
    void init(int x) {ch[x].fail=ch[x].val=0;memset(ch[x].son,0,sizeof(ch[x].son));}
    int idx(char c) {return c-'0';}

    void insert(char s[],int v) {
        int u=0,n=strlen(s);
        for(int i=0;i<n;i++) {
            int c=idx(s[i]);
            if(!ch[u].son[c]) {
                init(sz);
                ch[u].son[c]=sz++;
            }
            u=ch[u].son[c];
        }
        ch[u].val=v;
    }

    int change(int pre,int num) {
        if(ch[pre].val) return -1;
        int now=pre;
        for(int i=3;i>=0;i--) {
            if(ch[ch[now].son[(num>>i)&1]].val) return -1;
            now=ch[now].son[(num>>i)&1];
        }
        return now;
    }

    void build() {
        for(int i=0;i<sigma_size;i++) if(ch[0].son[i]) Q.push(ch[0].son[i]);
        while(!Q.empty()) {
            int now=Q.front();Q.pop();
            int fail=ch[now].fail;
            if(ch[fail].val) ch[now].val|=ch[fail].val;
            for(int i=0;i<sigma_size;i++) {
                int nxt=ch[now].son[i];
                if(nxt) {
                    ch[nxt].fail=ch[fail].son[i];
                    Q.push(nxt);
                }
                else ch[now].son[i]=ch[fail].son[i];
            }
        }
        for(int i=0;i<sz;i++) {
            for(int j=0;j<10;j++) {
                bcd[i][j]=change(i,j);//bcd[i][j]表示在自動機上第i個節點,加上j以後到達的節點
            }
        }
    }
};

LL dp[205][2005],cnt[205];

LL dfs(int len,int pos,bool flag, bool zero) {
    if(len==-1) return 1;
    if(flag==false&&dp[len][pos]!=-1) return dp[len][pos];
    LL ans=0;
    if(zero) {
        ans=(ans+dfs(len-1,pos,flag&&cnt[len]==0,true))%mod;
    }
    else {
        if(bcd[pos][0]!=-1) ans=(ans+dfs(len-1,bcd[pos][0],flag&&cnt[len]==0,false))%mod;
    }
    int ends=(flag?cnt[len]:9);
    for(int i=1;i<=ends;i++) {
        if(bcd[pos][i]!=-1) {
            ans=(ans+dfs(len-1,bcd[pos][i],flag&&i==ends,false))%mod;
        }
    }
    if(!flag&&!zero) dp[len][pos]=ans;
    return ans;
}

LL solve(char s[]) {
    memset(cnt,0,sizeof(cnt));
    int len=strlen(s);
    for(int i=0;i<strlen(s);i++) cnt[i]=s[len-1-i]-'0';
    return dfs(len-1,0,1,1);
}

int main()
{
    ios::sync_with_stdio(0);
    cin.tie(0);
    int T;cin>>T;
    while(T--) {
        AC ans;ans.init(0);
        int n;cin>>n;
        for(int i=1;i<=n;i++) {
            cin>>s;
            ans.insert(s,1);
        }
        ans.build();
        memset(dp,-1,sizeof(dp));
        cin>>s;
        for(int i=strlen(s)-1;i>=0;i--) {
            if(s[i]>'0') {
                s[i]--;
                break;
            }
            else s[i]='9';
        }
        LL res=-solve(s);
        cin>>s;
        res=(res%mod+solve(s)%mod+mod)%mod;
        cout<<res <<"\n";
    }
    return 0;
}