2020牛客暑期多校訓練營(第八場) H Hard String Problem
阿新 • • 發佈:2020-08-04
題解:\(bin\) 巨的題解其實已經很詳細了,這裡講幾個可能會踩到的坑,\(kmp\) 求最小迴圈節時我們要先把原來的串翻倍,也就是在串後面再接上這個串本身,因為 \(kmp\) 中的 \(len - next[len]\) 所得出來的迴圈節長度是可能的最小迴圈節長度,但我們要求的是準確的,比如 \(abcdabc\) ,如果直接用 \(kmp\) 求的話,得到的最小迴圈節會是 \(abcd\) , 展開後為 \(abcdabcdabcd\) , 而原串展開後為 \(abcdabcabcdabc\) ,有明顯的不同,然後我們將每個迴圈節最小表示,判斷迴圈節是否相同,之後就是按 \(bin\)
#include<bits/stdc++.h> using namespace std; typedef long long LL; const int maxn = 6e6 + 50; const int maxn2 = 3e5 + 50; string ss[maxn2], s; int knex[maxn]; int n; void getNext(int id){ int tlen = ss[id].size(); int j = 0, k = -1; knex[0] = -1; while(j < tlen){ if(k == -1 || ss[id][j] == ss[id][k]) knex[++j] = ++k; else k = knex[k]; } } int get_min(int id){ int len = ss[id].size(); int i = 0, j = 1, k = 0, t; while(i < len && j < len && k < len){ t = ss[id][(i + k) % len] - ss[id][(j + k) % len]; if(!t) k++; else { if(t > 0) i += k + 1; else j += k + 1; if(i == j) j++; k = 0; } } return min(i, j); } struct state { int len, link, nex[26]; } st[maxn]; int sz, last; void sam_init(){ st[0].len = 0; st[0].link = -1; sz = 1, last = 0; } void sam_extend(int x){ int cur = sz++; st[cur].len = st[last].len + 1; int p = last; while(p != -1 && !st[p].nex[x]){ st[p].nex[x] = cur; p = st[p].link; } if(p == -1) st[cur].link = 0; else { int q = st[p].nex[x]; if(st[p].len + 1 == st[q].len){ st[cur].link = q; } else { int clone = sz++; st[clone].len = st[p].len + 1; st[clone].link = st[q].link; for(int i = 0; i < 26; i++){st[clone].nex[i] = st[q].nex[i];} while(p != -1 && st[p].nex[x] == q){ st[p].nex[x] = clone; p = st[p].link; } st[q].link = st[cur].link = clone; } } last = cur; } LL val[maxn]; bool cmp(const string s1, const string s2){ return s1.size() < s2.size(); } int vis[maxn]; int main() { cin >> n; for(int i = 1; i <= n; i++){ cin >> ss[i]; ss[i] += ss[i]; getNext(i); int len = ss[i].size(); len = len - knex[len]; ss[i] = ss[i].substr(0, len); } for(int i = 1; i <= n; i++){ int st = get_min(i); int len = ss[i].size(); s = ""; for(int j = 0; j < len; j++){ s += ss[i][(st + j) % len]; } ss[i] = s; } int flag = 1; for(int i = 2; i <= n; i++){ if(ss[i] != ss[i - 1]) { flag = 0; break; } } if(flag){ cout << -1 << '\n'; return 0; } sort(ss + 1, ss + n + 1, cmp); for(int i = 2; i <= n; i++){ s = ss[i]; for(int j = 1; j <= 3; j++) ss[i] += s; } s = ss[1]; while(ss[1].size() < ss[n].size()){ ss[1] += s; } sam_init(); for(int i = 1; i <= n; i++){ int len = ss[i].size(); last = 0; for(int j = 0; j < len; j++){ sam_extend(ss[i][j] - 'a'); } } for(int i = 1; i <= n; i++){ int len = ss[i].size(); int p = 0; for(int j = 0; j < len; j++){ p = st[p].nex[ss[i][j] - 'a']; int u = p; while(vis[u] != i && u != 0) vis[u] = i, val[u]++, u = st[u].link; } } LL ans = 0; for(int i = 1; i < sz; i++){ if(val[i] == n) ans += st[i].len - st[st[i].link].len; } cout << ans << '\n'; }