1. 程式人生 > 實用技巧 >Codeforces 235C Cyclical Quest (字尾自動機)

Codeforces 235C Cyclical Quest (字尾自動機)

思路:一眼看過去,好像處理出每個字串的最小表示的 \(hash\) 值就可以解決了, 但想了複雜度明顯過不去,由於要統計某種子串個數,所以首先想到字尾自動機,然後分析,我們將每次查詢的模式串翻倍(接在自身後面),模式串的原本長度為 \(n\) ,假設我們現在在後綴自動機上找到了區間 \((le, ri)\) 的子串,首先判斷 \(ri - le + 1\) 是否等於 \(n\) , 若相等則加上該節點 \(ednpoints\) 集合大小,然後我們要查詢的就是 $(le + 1, ri + 1) $ 的子串了,首先看子串 \((le + 1, ri)\) 是否屬於該節點,若不屬於,則沿著 \(link\)

連結向上跳,跳到包含子串 \((le + 1, ri)\) 的節點 \(p\) ,然後判斷 \(st[p].next[s[ri + 1]]\) 是否存在,若存在,則 \(p\) 跳到 \(p = st[p].next[s[ri + 1]]\) ,否則 \(p\) 直接跳到 \(st[p].link\) , 並更新對應的 \(le\) 。具體看程式碼

#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
const int maxn = 1e5 + 50;
struct state {
  int len, link;
  int next[26];
};

state st[maxn * 20];
int sz, last;

void sam_init() {
  st[0].len = 0;
  st[0].link = -1;
  sz = 1;
  last = 0;
}

LL num[maxn * 20];
void sam_extend(int c) {
  int cur = sz++;
  st[cur].len = st[last].len + 1;
  int p = last;
  while (p != -1 && !st[p].next[c]) {
    st[p].next[c] = cur;
    p = st[p].link;
  }
  if (p == -1) {
    st[cur].link = 0;
  } else {
    int q = st[p].next[c];
    if (st[p].len + 1 == st[q].len) {
      st[cur].link = q;
    } else {
      int clone = sz++;
      st[clone].len = st[p].len + 1;
      for(int i = 0; i < 26; i++) st[clone].next[i] = st[q].next[i];
      st[clone].link = st[q].link;
      while (p != -1 && st[p].next[c] == q) {
        st[p].next[c] = clone;
        p = st[p].link;
      }
      st[q].link = st[cur].link = clone;
    }
  }
  last = cur;
}

struct Edge
{
	int to, next;
} edge[maxn * 40];

int k, head[maxn * 20];
void add(int a, int b){
	edge[k].to = b;
	edge[k].next = head[a];
	head[a] = k++;
}

void dfs(int u, int pre){
	for(int i = head[u]; i != -1; i = edge[i].next){
		int to = edge[i].to;
		if(to == pre) continue;
		dfs(to, u);
		num[u] += num[to];
	}
}
string s, t;
int vis[maxn * 20];
int main(int argc, char const *argv[])
{
	cin >> t;
	int tlen = t.size();
	sam_init();
	for(int i = 0; i < tlen; i++){
		sam_extend(t[i] - 'a');
		num[last] = 1;
	}
	for(int i = 0; i < sz; i++) head[i] = -1;
	for(int i = 1; i < sz; i++){
		add(i, st[i].link);
		add(st[i].link, i);
	}

	dfs(0, -1);
	int q;
	scanf("%d", &q);
	int id = 0;
	while(q--){
		id++;
		cin >> s;
		int n = s.size();
		s += s;
		int p = 0;
		int le = 0, ri = 0;
		LL ans = 0;
		while(le < n && ri < 2 * n){
			if(st[p].next[s[ri] - 'a']){
				p = st[p].next[s[ri] - 'a'];
				if(ri - le + 1 == n){
					if(vis[p] != id){ // 記錄一下該點的貢獻已經加過,防止重複算貢獻,比如第二個樣例
						vis[p] = id;
						ans += num[p];
					}
					le++;
					while(st[st[p].link].len + 1 > n - 1 && p != 0){
						p = st[p].link;
					}
				}
				ri++;
			} else {
				if(p == 0) le++, ri = le; // 注意,若 p 是節點 0 ,則需要讓 le++, 否則會死迴圈
				p = st[p].link; 
				le = ri - 1 - st[p].len + 1;
			}
		}
		printf("%lld\n", ans);
	}
	return 0;
}