718. Maximum Length of Repeated Subarray 字尾陣列解最長公共子串 O(n log^2 n)時間複雜度
阿新 • • 發佈:2018-11-11
題意
- 找最長公共子串
思路
- 用dp的方法很容易在O(n^2)解決問題,這裡主要討論用字尾陣列的思路解決這個問題
- 字尾數組裡有兩個經典的概念或者稱為資料結構,就是字尾陣列SA,以及高度陣列LCP
- SA陣列的定義是:將原串S所有的字尾按字典序排序後,定義rank(i)為字尾S[i…]的排名,SA陣列是rank陣列的逆對映,即SA(rank(i)) = i
- LCP陣列的定義是:LCP(i)是字尾S[SA(i)…]和字尾S[SA(i+1)…]的最長公共字首長度
- 這裡就不討論這兩個陣列的求解演算法了,我們使用比較簡單的倍增法求解SA陣列,複雜度是O(n log^2 n)的,有了SA陣列,求解LCP陣列是O(n)的
- 有了LCP陣列後,我們先來思考另一個問題,一個數組裡兩個不同的子串的最長公共子串是多長呢?答案是max(LCP),也就是LCP數組裡的最大值。原因的話反證一下很簡單,這裡簡單說明一下,主要是考慮陣列其它子串和以i開頭的子串的最長公共子串是多長,容易證明能達到最長的只能是以SA(rank(i)-1)或SA(rank(i)+1)開頭的子串,那麼這個結果都儲存在LCP裡了,所以遍歷一遍LCP就能找到最大值
- 利用上述結論,我們很容易解決新的問題了。可以把兩個陣列拼在一起,並在拼接處加一個特殊的int,是在兩個數組裡都沒有出現的
- 求出LCP陣列後,我們只要找i和SA(i+1)不在同一個字串的LCP的最大值即可
實現
class Solution {
public:
//size of rank and sa are n+1
//size of lcp is n, definition of lcp[i] is max common prefix of s[sa[i]...] and s[sa[i+1]...]
//input s of getSa and getLcp can be string as well
vector<int> rank, sa, lcp;
void getSa(const vector<int>& s, vector<int>& rank, vector<int>& sa){
int n = s.size();
vector<int> tmp(n+1);
for (int i = 0; i < n; i++){
sa.push_back(i);
rank.push_back(s[i]);
}
sa.push_back(n);
rank.push_back(-1);
for (int k = 1; k <= n; k <<= 1){
auto cmp = [&](int x, int y){
if (rank[x] != rank[y])
return rank[x] < rank[y];
auto tx = x + k > n ? -1 : rank[x + k];
auto ty = y + k > n ? -1 : rank[y + k];
return tx < ty;
};
sort(sa.begin(), sa.end(), cmp);
tmp[sa[0]] = 0;
for (int i = 1; i <= n; i++){
tmp[sa[i]] = tmp[sa[i-1]];
if (cmp(sa[i-1], sa[i])){
tmp[sa[i]]++;
}
}
for (int i = 0; i <= n; i++)
rank[i] = tmp[i];
}
}
void getLcp(const vector<int>& s, const vector<int>& rank, const vector<int>& sa,
vector<int>& lcp){
int n = s.size();
lcp.insert(lcp.begin(), n, 0);
for (int i = 0, h = 0; i < n; i++){
if (h > 0)
h--;
int k = rank[i];
int j = sa[k-1];
while (max(j, i) + h < n && s[j+h] == s[i+h]){
h++;
}
lcp[k-1] = h;
}
}
int findLength(vector<int>& A, vector<int>& B) {
int n = A.size(), m = B.size();
A.push_back(101);
A.insert(A.end(), B.begin(), B.end());
getSa(A, rank, sa);
getLcp(A, rank, sa, lcp);
int ans = 0;
for (int i = 0; i <= n + m; i++){
if (sa[i] < n && sa[i+1] > n || sa[i] > n && sa[i+1] < n){
ans = max(ans, lcp[i]);
}
}
return ans;
}
};