關於【AC自動姬】的學習
突然發現自己還不會AC自動姬,就去學了一下下。。。
自動姬de用途:
給你n個模式串,1個文字串,求出有多少模式串在文字串中出現過。(詳見洛谷P3808)
大體思路:
我先前貌似介紹過一種叫做KMP的神奇東西,還有叫做trie樹的逆天玩意兒,嗯,今天要講的AC自動姬就是兩者相結合,誕生的產品啦。
(圖片來自百度)
嗯,看到上面的圖片,讓我們手動過濾掉那些虛線,剩下的就是一顆加入了he,she,his,hers,四個字串的trie樹了。
基礎trie樹的構建(程式碼):
void insert(char *s){ int now=0; for (int i=0;s[i];i++){if (!trie[now][s[i]-'a']) trie[now][s[i]-'a']=++tot; now=trie[now][s[i]-'a']; } end[now]++; }
Fail指標:
那麼那些虛線又是什麼鬼呢?
其實,它就是自動姬的精髓所在,我們稱它為Fail指標(也叫失配指標)。
我覺得失配指標這個名字取的特別形象,因為fail指標的作用就是在文字串與當前串失去匹配後的退路。
什麼意思呢?
比如說:文字串在she的h處與模式串失去了匹配(也就是說文字串的下一個字元不是‘e’),也就是說,我們現在被字串“she”給炒魷魚了,那麼我們應該怎麼做呢,無家可歸了嗎?
不要慌——既然“she”不要你了,但是我們“his”還是要你的,畢竟你有字元‘h’在,和我們‘his’的前一個字元是相符合的,你可以來我們這裡試試看,假如說合適的話,就留下好了。
這樣就很容易理解了吧!
我們fail指標的作用就是為了找到這樣一條“後路”。
那麼假如說我們真的淪落到了沒人要的地步,也不用擔心,我們“家裡蹲”(家,是最溫暖的港灣)還是會收留你的,嗯,沒錯,真的沒人要的話,只要把fail指標指向根節點就行了。
我們可以採用BFS的方式來求出這個fail指標。
求出當前節點的fail指標的條件是:我們已經求出了之前所有節點的fail指標。 這裡的“之前”指的就是所有深度小於當前節點的點,所以才要用BFS來實現嘛。
為什麼會有這個前置條件呢,換句話說,為什麼我們知道了這些就能求出當前節點的fail指標呢?
這還得從fail指標的定義出發——失配後的退路,也就是說,我們要找到當前串的部分字尾與其它模式串的字首完全相同的節點,fail指標便是指向這個節點的。
所以,當前的fail指標肯定會指向深度小於等於當前節點的節點。
具體我們可以怎麼做呢,其實很簡單。
因為要和當前串的部分字尾完全相同,說明肯定要和前一個字元為止的部分字尾完全相同嘛,所以我們就先要找到fail[s[i-1]]指向的節點,假如說它有s[i]這個兒子的話,那麼fail指標就指向它的這個兒子了,如果沒有的話,我們就要順著這個節點的fail指標接著向上找,一直到找到或者到達根節點為止。
我們會很清楚的感覺到:在順著節點的fail指標不斷往上找的過程中,與當前串的字尾完全相同的字首的部分的長度是在不斷減小的。(貌似很拗口,但是如果你有這種感覺的話,就說明你對AC自動姬的學習已經有點感覺了)
Fail指標的構建(程式碼):
void build(){ queue <int> q; for (int i=0;i<26;i++) if (trie[0][i]) q.push(trie[0][i]); while (!q.empty()){ int now=q.front(); q.pop(); for (int i=0;i<26;i++) if (trie[now][i]) fail[trie[now][i]]=trie[fail[now]][i],q.push(trie[now][i]); else trie[now][i]=trie[fail[now]][i];//注意了,這行程式碼非常精髓,把原本我們可能要跳多次的環節,直接省略到了一次,因為這句話在執行一個“路徑壓縮”的操作,把一些沒用的(跳了之後發現沒有想要兒子的)跳躍全部都省去了 } }
與文字串的匹配:
有了構建Fail指標時的路徑壓縮之後,查詢操作就顯的簡單多了,因為路徑壓縮後,我們把一些字首與當前串的部分字尾相同的節點都連線到了這個節點下方,比如“she”和“her”,以為she的字尾he是與her的字首he完全相同的,所以路徑壓縮操作會給當前節點(當前節點是指通過she到達的節點)新加入一條邊‘r’,連到沿著he走到達的節點。所以我們就可以實現在自動姬上反覆橫跳啦。
查詢操作(程式碼):
int query(char *s){ int now=0,res=0; for (int i=0;s[i];i++){ now=trie[now][s[i]-'a']; for (int j=now;j&&~end[j];j=fail[j]) res+=end[j],end[j]=-1; } return res; }
完整程式碼:
#include <bits/stdc++.h> using namespace std; const int maxn=1000005; char s[maxn]; int trie[maxn][30],fail[maxn],end[maxn],n,tot; void insert(char *s){ int now=0; for (int i=0;s[i];i++){ if (!trie[now][s[i]-'a']) trie[now][s[i]-'a']=++tot; now=trie[now][s[i]-'a']; } end[now]++; } void build(){ queue <int> q; for (int i=0;i<26;i++) if (trie[0][i]) q.push(trie[0][i]); while (!q.empty()){ int now=q.front(); q.pop(); for (int i=0;i<26;i++) if (trie[now][i]) fail[trie[now][i]]=trie[fail[now]][i],q.push(trie[now][i]); else trie[now][i]=trie[fail[now]][i]; } } int query(char *s){ int now=0,res=0; for (int i=0;s[i];i++){ now=trie[now][s[i]-'a']; for (int j=now;j&&~end[j];j=fail[j]) res+=end[j],end[j]=-1; } return res; } int main(){ scanf("%d",&n); for (int i=1;i<=n;i++) scanf("%s",s),insert(s); build(); scanf("%s",s); printf("%d\n",query(s)); return 0; }