1. 程式人生 > >bzoj1212 [HNOI2004]L語言(DP+AC自動機)

bzoj1212 [HNOI2004]L語言(DP+AC自動機)

題目

求模式串對主串的最大匹配長度。

題解

DP+AC自動機
設f[i]=true表示前i位可以匹配出來,那麼轉移方案就是揹包的,f[i]|=f[i-len_j](要求:主串的字尾與模式串j完全一致),其中len[j]是一個模式串的長度。
如果要是大(bao)力DP的話,顯然會很慢,因為我們要配對每一個模式串。
顯然這種字尾配字首的問題應當交由AC自動機來處理。
把匹配串逐一insert到字典樹中,同時標記一下結尾,求一個fail指標,然後我們就可以來一波DP了。DP時,對於長度為i的字首,我們通過跳fail列舉所有 模式串的字首 與當前主串字尾相同的樹節點,一不小心碰到標記過的,說明滿足這個主串以i結尾的字尾與這個模式串匹配,此時符合DP方程的轉移條件,考慮更新f[i]。

程式碼

#include<cstdio>
#include<cstring>
#include<algorithm>
using namespace std;
const int maxn=30;
const int maxp=10*maxn;

char s[(const int)1e6+10];int len;
struct Tr{int x,fail,dep,son[25];}tr[maxp];int cnt=1,root=1;

bool flag[maxp];
void insert()
{
    int x=root;
    for(int i=1;i<=len;i++)
    {
        int k=s[i]-'a';
        if(tr[x].son[k]==0)
        {
            tr[x].son[k]=++cnt;
            tr[cnt].dep=tr[x].dep+1;
        }
        x=tr[x].son[k];
    }
    flag[x]=true;
}

int head,tail,q[maxp];
void getfail()
{
    head=0,tail=1;q[0]=root;
    while(head<tail)
    {
        int x=q[head++];
        for(int k=0;k<26;k++)//debug 一共26個字母! 
        {
            int y=tr[x].son[k];
            if(y==0) continue;
            else if(x==root) tr[y].fail=root;
            else
            {
                int p=tr[x].fail;//debug int p=x;
                while(p!=root && !tr[p].son[k]) p=tr[p].fail;
                //debug tr[y].fail=p;
                if(tr[p].son[k]) tr[y].fail=tr[p].son[k];
                else tr[y].fail=root;
            }
            q[tail++]=y;
        }
    }
}

int f[(const int)1e6+10];//f[i]表示長度為i的字首能否被翻譯 
void solve(int id)
{
    int x=root,ans=0;
    f[0]=id;
    for(int i=1;i<=len;i++)
    {
        int k=s[i]-'a';
        while(x!=root && !tr[x].son[k]) x=tr[x].fail;
        x= !tr[x].son[k]?root:tr[x].son[k] ;//debug x=tr[x].son[k];
        
        for(int p=x;p!=root;p=tr[p].fail)
        {
            if(flag[p] && i-tr[p].dep>=0 && f[i-tr[p].dep]==id){f[i]=id;break;}//DP轉移
        }
        if(f[i]==id) ans=i;
    }
    printf("%d\n",ans);
}

int main()
{
    int n,m;
    scanf("%d%d",&n,&m);
    tr[root].dep=0;
    for(int i=1;i<=n;i++)
    {
        scanf("%s",s+1);
        len=strlen(s+1);
        insert();
    }
    getfail();
//    for(int i=1;i<=cnt;i++)
//    {
//        printf("fail %d  : %d\n",i,tr[i].fail);
//    }
    for(int i=1;i<=m;i++)
    {
        scanf("%s",s+1);
        len=strlen(s+1);
        solve(i);
    }
    return 0;
}