hdu6153字尾陣列或擴充套件KMP
前兩天刷了幾題leetcode,感覺挺簡單,於是又想刷刷hduoj了。隨便開啟沒做過的一頁,找了一題通過人數最多的,就是這道6153.
①.看完題沒想太多,覺得應該是字尾陣列(多年沒刷題的我字串這一塊對字尾陣列記憶最深吧),因為S1和S2長度都一百萬,n^2受不了。nlogn應該行。
②.用字尾陣列的話,需要會用字尾陣列求子串出現次數。如果是任意子串,還不太好辦,但是這裡的子串只會是字尾,那就好辦了——只需看此後綴在所有後綴中的字首的數量,也即與本字尾的LCP等於len(s)的字尾數量。因為字尾陣列是排過序的,只要往後看,直到height值小於len(s)為止。
③.會了②,就好辦了。把S1和S2連起來(中間插入一個比如~號,記為S),求出S的字尾陣列SA;又對S2求出字尾陣列SA2。
對於每一個S2的字尾,用在SA中求出的出現次數減去SA2中求出的出現次數,就是此後綴在S1中的出現次數(因為它不會在跨S1和S2時出現)。
④.複雜度分析。構造字尾陣列最優演算法是O(n),用倍增演算法構造字尾陣列是O(nlogn)。後面的計算步驟,因為S2長度n為一百萬,要計算n次,每次用②中的辦法往後找,最壞情況要O(n)。但是不意味著這裡的總複雜度只能O(n^2)。因為字尾陣列是排好序的,按字尾陣列的逆序處理這n次查詢,那麼第k次查詢可以充分利用第k-1次查詢的結果,往後滑動。也就是說這n次查詢基本不會重疊,所以最後查詢計算部分的總複雜度仍為O(n)。
這個地方我舉個例子,以字串aabaaaab~aabaaaab為例,它有17個字尾,在後綴陣列中的順序為:
aaaab
aaaab~aabaaaab
aaab
aaab~aabaaaab
aab
aabaaaab
aabaaaab~aabaaaab
aab~aabaaaab
ab
abaaaab
abaaaab~aabaaaab
ab~aabaaaab
b
baaaab
baaaab~aabaaaab
b~aabaaaab
~aabaaaab
求出字尾baaaab的出現次數為2以後,再求字尾b的出現次數時,因為知道b是baaaab的子串(height值等於b的長度),可以直接滑到b~aabaaaab進行判斷。
那最後整個演算法的複雜度還是卡在構造字尾陣列部分,如果用倍增,那整個演算法的複雜度為O(nlogn)
⑤.具體實現上,因為我要複用兩次後綴陣列的程式碼,所以每次預處理出data陣列,再處理出d陣列。
d陣列就是逆序存的每個字尾在整個串中出現的次數。
程式碼如下:
1 /* 2 * Author : ben 3 */ 4 #include <cstdio> 5 #include <cstdlib> 6 #include <cstring> 7 typedef long long LL; 8 const int MAXN = 2010000; 9 char s[MAXN]; 10 int sa[MAXN], height[MAXN], rank[MAXN], N; 11 int tmp[MAXN], top[MAXN]; 12 void makesa() { 13 inti, j, len, na; 14 na = (N < 256 ? 256 : N); 15 memset(top, 0, na * sizeof(int)); 16 for (i = 0; i < N; i++) { 17 top[rank[i] = s[i] & 0xff]++; 18 } 19 for (i = 1; i < na; i++) { 20 top[i] += top[i - 1]; 21 } 22 for (i = 0; i < N; i++) {23 sa[--top[rank[i]]] = i; 24 } 25 for (len = 1; len < N; len <<= 1) { 26 for (i = 0; i < N; i++) { 27 j = sa[i] - len; 28 if (j < 0) { 29 j += N; 30 } 31 tmp[top[rank[j]]++] = j; 32 } 33 sa[tmp[top[0] = 0]] = j = 0; 34 for (i = 1; i < N; i++) { 35 if (rank[tmp[i]] != rank[tmp[i - 1]] 36 || rank[tmp[i] + len] != rank[tmp[i - 1] + len]) { 37 top[++j] = i; 38 } 39 sa[tmp[i]] = j; 40 } 41 memcpy(rank, sa, N * sizeof(int)); 42 memcpy(sa, tmp, N * sizeof(int)); 43 if (j >= N - 1) { 44 break; 45 } 46 } 47 } 48 49 void lcp() { 50 int i, j, k; 51 for (j = rank[height[i = k = 0] = 0]; i < N - 1; i++, k++) { 52 while (k >= 0 && s[i] != s[sa[j - 1] + k]) { 53 height[j] = (k--), j = rank[sa[j] + 1]; 54 } 55 } 56 } 57 58 char S1[MAXN], S2[MAXN]; 59 int data1[MAXN], data2[MAXN]; 60 int d[MAXN]; 61 62 void makedata(int *data) { 63 data[0] = 1; 64 for (int i = N - 2; i > 0; i--) { 65 int leni = N - 1 - sa[i]; 66 int j = N - i - 1; 67 if (height[i + 1] < leni) { 68 data[j] = 1; 69 } else { 70 int k = i + data[j - 1] + 1; 71 while (k < N && height[k] >= leni) { 72 k++; 73 } 74 data[j] = k - i; 75 } 76 // cout << data[j] << endl; 77 } 78 } 79 80 const LL MOD_NUM = 1000000007LL; 81 int work(int lens1, int lens2) { 82 int ans = 0; 83 strcpy(s, S2); 84 N = lens2 + 1; 85 makesa(); 86 lcp(); 87 makedata(data1); 88 for (int i = 0; i < lens2; i++) { 89 int j = lens2 - i - 1; 90 d[i] = data1[lens2 - rank[j]]; 91 // cout << d[i] << endl; 92 } 93 strcpy(s, S1); 94 s[lens1] = '#'; 95 s[lens1 + 1] = 0; 96 strcat(s, S2); 97 N = lens1 + lens2 + 2; 98 makesa(); 99 lcp(); 100 makedata(data2); 101 for (int i = 0; i < lens2; i++) { 102 int j = N - i - 2; 103 d[i] = data2[N - 1 - rank[j]] - d[i]; 104 ans = ans % MOD_NUM + ((i + 1LL) * d[i]) % MOD_NUM; 105 // cout << d[i] << endl; 106 } 107 return ans % MOD_NUM; 108 } 109 110 int main() { 111 int T; 112 scanf("%d", &T); 113 while (T--) { 114 scanf("%s%s", S1, S2); 115 int lens1 = strlen(S1); 116 int lens2 = strlen(S2); 117 printf("%d\n", work(lens1, lens2)); 118 } 119 return 0; 120 }
然而,提交上去,超時了。看來資料量很強。是卡在後綴陣列構造的倍增演算法上了。但是我手頭沒有DC3等更優演算法的模板。
轉念一想,這題通過的人這麼多,不可能需要高階字尾陣列演算法的。於是回憶了一下還有別的什麼演算法。想到了擴充套件KMP。
用擴充套件KMP解此題的思路主要是要把S1和S2逆序,逆序以後,題目要求的S2的每個字尾就變成字首了。根據extend陣列的定義我們知道,如果extend[i] = x,則表示S2的前x個字元與S1從i開始的x個字元相同,統計extend陣列中有多少個x就知道S2的這個字首在S1中出現的次數。這種統計是可以線上性時間完成的,而前面生成extend陣列的時間也為線性,故最後整體複雜度也為線性O(N)。程式碼如下
1 /* 2 * Author : ben 3 */ 4 #include <cstdio> 5 #include <cstdlib> 6 #include <cstring> 7 #include <cmath> 8 #include <ctime> 9 #include <algorithm> 10 typedef long long LL; 11 const int MAXN = 1001000; 12 char S1[MAXN], S2[MAXN]; 13 int d[MAXN]; 14 const LL MOD_NUM = 1000000007LL; 15 int next[MAXN], extend[MAXN]; 16 void get_next(const char *str, int len){ 17 // 計算next[0]和next[1] 18 next[0] = len; 19 int i = 0; 20 while(str[i] == str[i + 1] && i + 1 < len) { 21 i++; 22 } 23 next[1] = i; 24 int po = 1; //初始化po的位置 25 for(i = 2; i < len; i++) { 26 if(next[i - po] + i < next[po] + po) { //第一種情況,可以直接得到next[i]的值 27 next[i] = next[i - po]; 28 } else { //第二種情況,要繼續匹配才能得到next[i]的值 29 int j = next[po] + po - i; 30 if(j < 0) { 31 j = 0; //如果i > po + next[po],則要從頭開始匹配 32 } 33 while(i + j < len && str[j] == str[j + i]) { //計算next[i] 34 j++; 35 } 36 next[i] = j; 37 po = i; //更新po的位置 38 } 39 } 40 } 41 void extend_KMP(const char *str, int lens, const char *pattern, int lenp) { 42 get_next(pattern, lenp); // 先計算模式串的next陣列 43 // 計算extend[0] 44 int i = 0; 45 while(str[i] == pattern[i] && i < lenp && i < lens) { 46 i++; 47 } 48 extend[0] = i; 49 int po = 0; // 初始化po的位置 50 for(i = 1; i < lens; i++) { 51 if(next[i - po] + i < extend[po] + po) { //第一種情況,直接可以得到extend[i]的值 52 extend[i] = next[i - po]; 53 } else { // 第二種情況,要繼續匹配才能得到extend[i]的值 54 int j = extend[po] + po - i; 55 if(j < 0) { 56 j = 0; //如果i > extend[po] + po則要從頭開始匹配 57 } 58 while(i + j < lens && j < lenp && str[j + i] == pattern[j]) { // 計算extend[i] 59 j++; 60 } 61 extend[i] = j; 62 po = i; // 更新po的位置 63 } 64 } 65 } 66 67 int main() { 68 int T; 69 scanf("%d", &T); 70 while (T--) { 71 scanf("%s%s", S1, S2); 72 int len1 = strlen(S1); 73 int len2 = strlen(S2); 74 std::reverse(S1, S1 + len1); 75 std::reverse(S2, S2 + len2); 76 extend_KMP(S1, len1, S2, len2); 77 memset(d, 0, sizeof(d)); 78 for (int i = 0; i < len1; i++) { 79 // printf("%d ", extend[i]); 80 d[extend[i]]++; 81 } 82 LL total = 0LL; 83 int ans = 0; 84 for (int j = len2; j > 0; j--) { 85 total = (total + d[j]) % MOD_NUM; 86 ans = (ans + total * j) % MOD_NUM; 87 } 88 printf("%d\n", ans); 89 // putchar('\n'); 90 } 91 return 0; 92 }
最後,在做此題的過程中我還突然發現輸入外掛沒用了,加上輸入外掛後的執行時間比直接用scanf更長。也許是因為現在的oj用上了最新的編譯器吧。