1. 程式人生 > 實用技巧 >芝士:字尾自動機(SAM)

芝士:字尾自動機(SAM)

背景

考慮能否有一種資料結構能接受一個固定字串所有的子串

雖然其的名稱為字尾自動機(SAM)

原理

考慮**每一個子串 **有一個集合\(endpos\),表示為這個子串在哪些位置上出現過,將這些位置的最後一個字元所匹配的位置拿出來

比如\(aababa\),其中\(ba\)\(endpos=\{4,6\}\)

我們稱每一個\(endpos\)為一個類

引理1

如果有一個子串\(t\)和一個子串\(s\),如果\(endpos_t\in endpos_s\),那麼一定有\(s\)\(t\)的一個字尾,這應該很容易理解吧

引理2

規定\(len_s<len_t\)

將引理1反過來,如果\(s\)

\(t\)的一個字尾,那麼一定有\(endpos_t\in endpos_s\),應該這也是很容易理解的

反之,如果\(s\)不是\(t\)的一個字尾,那麼$endpos_t\bigcap endpos_s=\varnothing $

考慮反證法,如果\(s\)不是\(t\)的一個字尾,且\(endpos_t\bigcap endpos_s\neq \varnothing\)

\(endpos_t\bigcap endpos_s=A\)

那麼如果在\(A_i\)位置兩者的\(endpos\)是一樣的,那麼\([A_i-len_t+1,A_i]\)即為\(t\),\([A_i-len_s+1,A_i]\)

即為\(s\),那麼\(s\)必為\(t\)的一個字尾,故假設不成立

引理3

對於相同的\(endpos\),考慮設其中長度最小的一個為\(s\),最長的一個為\(t\)

這裡單獨將\(t\)拿出來,只有這裡的下標是以t為基準

那麼有\(\forall i\in [1,len_t-len_s+1]\),有\([i,len_t]\)\(endpos\)\(t\)\(endpos\)是一樣的

比如對於一個任意的類,其中最長的為\(aababbb\),最短的為\(bbb\)

那麼\(aababbb,ababbb,babbb,abbb,bbb\),這5個子串的\(endpos\)是相同的

引理4

\(endpos\)不同的類最多隻有\(n\)個(這裡指的是\(n\)級別)

對於一個等價類,考慮往前面加只加一個字元,必然會導致\(endpos\)裂開

但是可以保證的是,所有裂開之後的\(endpos\)是沒有交集的,同時,這些$endpos \(中的值一定是來源於原來的\)endpos$,

基於此,如果有兩個\(endpos_s\)\(endpos_t\),滿足\(endpos_s\bigcap endpos_t=\varnothing\),那麼其裂出來的\(endpos\)之間一定是兩兩之間交集為空

也就是指對於子串的變化,實際上就是將\(endpos\)進行分裂

最初的\(endpos\)\(\{1,2\ldots n\}\),空串

利用線段樹的思想,很容易得到其所有的節點數不超過\(2n\)

也就是指不同\(endpos\)不會超過\(2n\)

引理5

考慮一個\(endpos\)的最長的字串長度為\(max\),其最短的字串長度為\(min\)

那麼有\(min_u=max_{fa_u}+1\),這裡的邊是指的是用引理4中提到的裂開和沒裂開的\(endpos\)之間的連邊

這應該也是顯然的吧

現在考慮怎麼去完成一個自動機應該有的功能,即能接受所有的子串

考慮現在已經將\(endpos\)將整個樹建了起來

比如原串為\(abaaab\)

也許長成這個樣子

其中點表示\(endpos\)的集合

考慮新增一些邊進去,使得到節點u的路徑都覆蓋所有的\(endpos\)所代表的字串,可以證明,其增加的邊數也是\(n\)級別的,

證明:@!(#@!*(……%!(@#

實際上是因為直接筆者聽證明的時候。。。。。

構造

好,現在已經知道了SAM大概的原理了,現在考慮怎麼在一個優秀的時間複雜度內構造出SAM,

不會吧,不會真的有人想用上面的原理直接構造吧

根據巨佬們的一次次的探索,SAM大概長成這個樣子

最下面的一行節點表示所有的字首

現在考慮,

如果巨佬給了你一個按\([1,n-1]\)已經建好的SAM,現在你怎麼新增一個字串進去,使其依然是一個SAM,

考慮到上面所說的原理,這裡只討論最基本的自動機

每一個點需要維護3個資訊,\(tre[i].len,tre[i].fa,tre[i].ch[]\)

這裡的fa陣列是用endpos建出來的樹的邊

這裡的ch陣列才是SAM上的邊

其中\(len\)表示最長的符合\(endpos\)的子串,\(fa\)就直接表示父親節點,\(ch[]\)表示當前節點後面添加了字元\(c\),下一個節點會達到哪裡

我們考慮利用最下面一行的進行構造

設上一次的字首為\(las\)

\(las\)往上的節點的\(endpos\)一定包含\(n-1\),考慮在最後新加一個節點,這些節點如果沒有字元為\(c\)的都必須連一條邊過來

情況1

如果這條鏈上所有的節點都沒有字元\(c\)的兒子,那麼說明新加節點一定會構成一個新的類,故直接連上去就行了,只需要改\(fa\)就行了

情況2

設當前的節點為\(p\),其連出去的邊為\(q\)

如果\(tre[p].len+1==tre[q].len\)

這個時候直接連上去就行了,同樣的,只改\(fa\)

情況3

設當前的節點為\(p\),其連出去的邊為\(q\)

如果\(tre[p].len+1\neq tre[q].len\)

意味著,我們需要將\(q\)這個節點裂開

同時將一些節點的\(ch\)改掉

程式碼

例題傳送門

#include<iostream>
#include<cstring>
#include<cstdio>
#include<queue>
using namespace std;
namespace SAM
{
    struct node
    {
        int fa;
        int len;
        int ch[30];
        node()
        {
            memset(ch,0,sizeof(ch));
        }
    }tre[2000005];
    int las=1,tot=1;
    int dp[200005],d[200005];
    int getf(char c)
    {
        return c-'a'+1;
    }
    void add(int c)
    {  
        int p=las;
        int np=las=++tot;
        dp[tot]=1;
        //cout<<"bas:"<<tot<<'\n';
        tre[np].len=tre[p].len+1;
        for(;p&&!tre[p].ch[c];p=tre[p].fa)
            tre[p].ch[c]=np;
        if(!p)
        {
            tre[np].fa=1;
            return;
        }
        else
        {
            int q=tre[p].ch[c];
            if(tre[q].len==tre[p].len+1)
            {
                tre[np].fa=q;
                return;
            }
            else
            {
                int nq=++tot;
                tre[nq]=tre[q];
                tre[nq].len=tre[p].len+1;
                tre[q].fa=tre[np].fa=nq;
                for(;p&&tre[p].ch[c]==q;p=tre[p].fa)
                    tre[p].ch[c]=nq;
            }
        }
    }
    long long getdp()
    {
        queue<int> q;
        for(int i=1;i<=tot;i++)
            d[tre[i].fa]++;
        for(int i=1;i<=tot;i++)
            if(d[i]==0)
                q.push(i);
        while(!q.empty())
        {
            int t=q.front();
            q.pop();
            int v=tre[t].fa;
            dp[v]+=dp[t];
            d[v]--;
            if(d[v]==0)
                q.push(v); 
        }
        long long ans=0;
        for(int i=1;i<=tot;i++)
            if(dp[i]!=1)
                ans=max(ans,1ll*dp[i]*tre[i].len);
        return ans;
    }
}
using namespace SAM;
char s[1000005];
int lens;
int main()
{
    cin>>(s+1);lens=strlen(s+1);
    for(int i=1;i<=lens;i++)
        add(getf(s[i]));
    cout<<getdp();
    return 0;
}