1. 程式人生 > 實用技巧 >【洛谷P3804】【模板】字尾自動機 (SAM)

【洛谷P3804】【模板】字尾自動機 (SAM)

題目

題目連結:https://www.luogu.com.cn/problem/P3804
給定一個只包含小寫字母的字串\(S\),
請你求出 \(S\) 的所有出現次數不為 \(1\) 的子串的出現次數乘上該子串長度的最大值。
\(|S|\leq 10^6\)

思路

推薦 BlogOIwiki
SAM 可以線上性複雜度內解決一些關於子串的問題。例如經典的本質不同子串個數。
我們定義 \(\mathrm{endpos}(t)\) 表示子串 \(t\) 在原串 \(s\) 中出現位置集合。具體的,

\[\mathrm{endpos}(t)=\{x|s_{x-|t|+1}\cdots s_x=t\} \]

對於 \(\mathrm{endpos}(t)\) 相同的所有子串 \(t\),我們把他們歸到一個等價類中,然後對於每一個等價類建立一個節點。
如果 \(\mathrm{endpos}(t)⫋\mathrm{endpos}(t')\),且不存在任意 \(t''(|t''|>|t'|)\) 滿足 \(\mathrm{endpos}(t)⫋\mathrm{endpos}(t'')\),那麼我們就從 \(t'\) 所在類的節點向 \(t\) 所在類的節點連一條邊。
容易發現,連邊結束後,所有的點構成了一棵樹,我們稱其為 parent 樹。可以證明樹上的點數是 \(O(n)\) 的。
我們記 \(\mathrm{len}(x)\)

表示節點 \(x\) 所在的等價類中,長度最長的子串的長度,\(\mathrm{minlen}(x)\) 表示最短長度。那麼容易發現 \(\mathrm{minlen}(x)=\mathrm{len}(fa_x)+1\)。因為顯然在原串中它們是包含關係且只差一個字元。
一個字尾自動機(SAM)的節點和 parent 樹完全一致,但是連邊方式不同。在 SAM 中,一個類所對應的節點會向另一個類連一條有向邊,當且僅當在這個類的任意子串的末尾新增一個字元 \(c\),得到的串的 \(\mathrm{endpos}\) 集合等於後者所在的類的 \(\mathrm{endpos}\) 集合。那麼這條有向邊的權值就為這個字元 \(c\)

依舊是可以證明,SAM 的邊數是 \(O(n)\) 的。具體可以看上面推薦的文章。wtcl。
然後利用 parent 樹和字尾自動機,就可以解決字串的很多問題。但是為了保證複雜度,如果需要排序,那麼要採用基數排序。


回到本題,建出 parent 樹後,容易發現一個點所表示的等價類中,所有子串出現的次數都等於 parent 樹上這個節點的子樹大小。
那麼直接求出每一個節點的大小,乘上 \(\mathrm{len_x}\) 取最大值即可。
時間複雜度 \(O(n)\)

程式碼

#include <bits/stdc++.h>
using namespace std;
typedef long long ll;

const int N=2000010;
int n,a[N],cnt[N];
ll ans,f[N];
char s[N];

struct SAM
{
	int tot,last,ch[N][26],fa[N],len[N];
	SAM() { tot=last=1; }

	void ins(int c)
	{
		int np=++tot,p=last; 
		len[np]=len[p]+1; last=np; f[np]++;
		for (;p && !ch[p][c];p=fa[p]) ch[p][c]=np;
		if (!p) fa[np]=1;
		else
		{
			int q=ch[p][c];
			if (len[q]==len[p]+1) fa[np]=q;
			else
			{
				int nq=++tot;
				len[nq]=len[p]+1; fa[nq]=fa[q];
				for (int i=0;i<26;i++) ch[nq][i]=ch[q][i];
				fa[q]=fa[np]=nq;
				for (;p && ch[p][c]==q;p=fa[p]) ch[p][c]=nq;
			}
		}
	}
	
	void topsort()
	{
		for (int i=1;i<=tot;i++) cnt[len[i]]++;
		for (int i=1;i<=tot;i++) cnt[i]+=cnt[i-1];
		for (int i=tot;i>=1;i--) a[cnt[len[i]]--]=i;
		for (int i=tot;i>=1;i--)
		{
			if (f[a[i]]>1) ans=max(ans,f[a[i]]*len[a[i]]);
			f[fa[a[i]]]+=f[a[i]];
		}
	}
}sam;

int main()
{
	scanf("%s",s+1);
	n=strlen(s+1);
	for (int i=1;i<=n;i++)
		sam.ins(s[i]-'a');
	sam.topsort();
	printf("%lld\n",ans);
	return 0;
}