1. 程式人生 > 其它 >hdu6153字尾陣列或擴充套件KMP

hdu6153字尾陣列或擴充套件KMP

前兩天刷了幾題leetcode,感覺挺簡單,於是又想刷刷hduoj了。隨便開啟沒做過的一頁,找了一題通過人數最多的,就是這道6153.

①.看完題沒想太多,覺得應該是字尾陣列(多年沒刷題的我字串這一塊對字尾陣列記憶最深吧),因為S1和S2長度都一百萬,n^2受不了。nlogn應該行。
②.用字尾陣列的話,需要會用字尾陣列求子串出現次數。如果是任意子串,還不太好辦,但是這裡的子串只會是字尾,那就好辦了——只需看此後綴在所有後綴中的字首的數量,也即與本字尾的LCP等於len(s)的字尾數量。因為字尾陣列是排過序的,只要往後看,直到height值小於len(s)為止。
③.會了②,就好辦了。把S1和S2連起來(中間插入一個比如~號,記為S),求出S的字尾陣列SA;又對S2求出字尾陣列SA2。
對於每一個S2的字尾,用在SA中求出的出現次數減去SA2中求出的出現次數,就是此後綴在S1中的出現次數(因為它不會在跨S1和S2時出現)。
④.複雜度分析。構造字尾陣列最優演算法是O(n),用倍增演算法構造字尾陣列是O(nlogn)。後面的計算步驟,因為S2長度n為一百萬,要計算n次,每次用②中的辦法往後找,最壞情況要O(n)。但是不意味著這裡的總複雜度只能O(n^2)。因為字尾陣列是排好序的,按字尾陣列的逆序處理這n次查詢,那麼第k次查詢可以充分利用第k-1次查詢的結果,往後滑動。也就是說這n次查詢基本不會重疊,所以最後查詢計算部分的總複雜度仍為O(n)。
這個地方我舉個例子,以字串aabaaaab~aabaaaab為例,它有17個字尾,在後綴陣列中的順序為:
aaaab
aaaab~aabaaaab
aaab
aaab~aabaaaab
aab
aabaaaab
aabaaaab~aabaaaab
aab~aabaaaab
ab
abaaaab
abaaaab~aabaaaab
ab~aabaaaab
b
baaaab
baaaab~aabaaaab
b~aabaaaab
~aabaaaab
求出字尾baaaab的出現次數為2以後,再求字尾b的出現次數時,因為知道b是baaaab的子串(height值等於b的長度),可以直接滑到b~aabaaaab進行判斷。
那最後整個演算法的複雜度還是卡在構造字尾陣列部分,如果用倍增,那整個演算法的複雜度為O(nlogn)
⑤.具體實現上,因為我要複用兩次後綴陣列的程式碼,所以每次預處理出data陣列,再處理出d陣列。
d陣列就是逆序存的每個字尾在整個串中出現的次數。
程式碼如下:

  1 /*
  2  * Author    : ben
  3  */
  4 #include <cstdio>
  5 #include <cstdlib>
  6 #include <cstring>
  7 typedef long long LL;
  8 const int MAXN = 2010000;
  9 char s[MAXN];
 10 int sa[MAXN], height[MAXN], rank[MAXN], N;
 11 int tmp[MAXN], top[MAXN];
 12 void makesa() {
 13     int
i, j, len, na; 14 na = (N < 256 ? 256 : N); 15 memset(top, 0, na * sizeof(int)); 16 for (i = 0; i < N; i++) { 17 top[rank[i] = s[i] & 0xff]++; 18 } 19 for (i = 1; i < na; i++) { 20 top[i] += top[i - 1]; 21 } 22 for (i = 0; i < N; i++) {
23 sa[--top[rank[i]]] = i; 24 } 25 for (len = 1; len < N; len <<= 1) { 26 for (i = 0; i < N; i++) { 27 j = sa[i] - len; 28 if (j < 0) { 29 j += N; 30 } 31 tmp[top[rank[j]]++] = j; 32 } 33 sa[tmp[top[0] = 0]] = j = 0; 34 for (i = 1; i < N; i++) { 35 if (rank[tmp[i]] != rank[tmp[i - 1]] 36 || rank[tmp[i] + len] != rank[tmp[i - 1] + len]) { 37 top[++j] = i; 38 } 39 sa[tmp[i]] = j; 40 } 41 memcpy(rank, sa, N * sizeof(int)); 42 memcpy(sa, tmp, N * sizeof(int)); 43 if (j >= N - 1) { 44 break; 45 } 46 } 47 } 48 49 void lcp() { 50 int i, j, k; 51 for (j = rank[height[i = k = 0] = 0]; i < N - 1; i++, k++) { 52 while (k >= 0 && s[i] != s[sa[j - 1] + k]) { 53 height[j] = (k--), j = rank[sa[j] + 1]; 54 } 55 } 56 } 57 58 char S1[MAXN], S2[MAXN]; 59 int data1[MAXN], data2[MAXN]; 60 int d[MAXN]; 61 62 void makedata(int *data) { 63 data[0] = 1; 64 for (int i = N - 2; i > 0; i--) { 65 int leni = N - 1 - sa[i]; 66 int j = N - i - 1; 67 if (height[i + 1] < leni) { 68 data[j] = 1; 69 } else { 70 int k = i + data[j - 1] + 1; 71 while (k < N && height[k] >= leni) { 72 k++; 73 } 74 data[j] = k - i; 75 } 76 // cout << data[j] << endl; 77 } 78 } 79 80 const LL MOD_NUM = 1000000007LL; 81 int work(int lens1, int lens2) { 82 int ans = 0; 83 strcpy(s, S2); 84 N = lens2 + 1; 85 makesa(); 86 lcp(); 87 makedata(data1); 88 for (int i = 0; i < lens2; i++) { 89 int j = lens2 - i - 1; 90 d[i] = data1[lens2 - rank[j]]; 91 // cout << d[i] << endl; 92 } 93 strcpy(s, S1); 94 s[lens1] = '#'; 95 s[lens1 + 1] = 0; 96 strcat(s, S2); 97 N = lens1 + lens2 + 2; 98 makesa(); 99 lcp(); 100 makedata(data2); 101 for (int i = 0; i < lens2; i++) { 102 int j = N - i - 2; 103 d[i] = data2[N - 1 - rank[j]] - d[i]; 104 ans = ans % MOD_NUM + ((i + 1LL) * d[i]) % MOD_NUM; 105 // cout << d[i] << endl; 106 } 107 return ans % MOD_NUM; 108 } 109 110 int main() { 111 int T; 112 scanf("%d", &T); 113 while (T--) { 114 scanf("%s%s", S1, S2); 115 int lens1 = strlen(S1); 116 int lens2 = strlen(S2); 117 printf("%d\n", work(lens1, lens2)); 118 } 119 return 0; 120 }

然而,提交上去,超時了看來資料量很強。是卡在後綴陣列構造的倍增演算法上了但是我手頭沒有DC3等更優演算法的模板。

轉念一想,這題通過的人這麼多,不可能需要高階字尾陣列演算法的。於是回憶了一下還有別的什麼演算法。想到了擴充套件KMP。

用擴充套件KMP解此題的思路主要是要把S1和S2逆序,逆序以後,題目要求的S2的每個字尾就變成字首了。根據extend陣列的定義我們知道,如果extend[i] = x,則表示S2的前x個字元與S1從i開始的x個字元相同,統計extend陣列中有多少個x就知道S2的這個字首在S1中出現的次數。這種統計是可以線上性時間完成的,而前面生成extend陣列的時間也為線性,故最後整體複雜度也為線性O(N)。程式碼如下

 1 /*
 2  * Author    : ben
 3  */
 4 #include <cstdio>
 5 #include <cstdlib>
 6 #include <cstring>
 7 #include <cmath>
 8 #include <ctime>
 9 #include <algorithm>
10 typedef long long LL;
11 const int MAXN = 1001000;
12 char S1[MAXN], S2[MAXN];
13 int d[MAXN];
14 const LL MOD_NUM = 1000000007LL;
15 int next[MAXN], extend[MAXN];
16 void get_next(const char *str, int len){
17     // 計算next[0]和next[1]
18     next[0] = len;
19     int i = 0;
20     while(str[i] == str[i + 1] && i + 1 < len) {
21         i++;
22     }
23     next[1] = i;
24     int po = 1; //初始化po的位置
25     for(i = 2; i < len; i++) {
26         if(next[i - po] + i < next[po] + po) { //第一種情況,可以直接得到next[i]的值
27             next[i] = next[i - po];
28         } else { //第二種情況,要繼續匹配才能得到next[i]的值
29             int j = next[po] + po - i;
30             if(j < 0) {
31                 j = 0; //如果i > po + next[po],則要從頭開始匹配
32             }
33             while(i + j < len && str[j] == str[j + i]) { //計算next[i]
34                 j++;
35             }
36             next[i] = j;
37             po = i; //更新po的位置
38         }
39     }
40 }
41 void extend_KMP(const char *str, int lens, const char *pattern, int lenp) {
42     get_next(pattern, lenp); // 先計算模式串的next陣列
43     // 計算extend[0]
44     int i = 0;
45     while(str[i] == pattern[i] && i < lenp && i < lens) {
46         i++;
47     }
48     extend[0] = i;
49     int po = 0; // 初始化po的位置
50     for(i = 1; i < lens; i++) {
51         if(next[i - po] + i < extend[po] + po) { //第一種情況,直接可以得到extend[i]的值
52             extend[i] = next[i - po];
53         } else { // 第二種情況,要繼續匹配才能得到extend[i]的值
54             int j = extend[po] + po - i;
55             if(j < 0) {
56                 j = 0; //如果i > extend[po] + po則要從頭開始匹配
57             }
58             while(i + j < lens && j < lenp && str[j + i] == pattern[j]) { // 計算extend[i]
59                 j++;
60             }
61             extend[i] = j;
62             po = i; // 更新po的位置
63         }
64     }
65 }
66 
67 int main() {
68     int T;
69     scanf("%d", &T);
70     while (T--) {
71         scanf("%s%s", S1, S2);
72         int len1 = strlen(S1);
73         int len2 = strlen(S2);
74         std::reverse(S1, S1 + len1);
75         std::reverse(S2, S2 + len2);
76         extend_KMP(S1, len1, S2, len2);
77         memset(d, 0, sizeof(d));
78         for (int i = 0; i < len1; i++) {
79 //            printf("%d ", extend[i]);
80             d[extend[i]]++;
81         }
82         LL total = 0LL;
83         int ans = 0;
84         for (int j = len2; j > 0; j--) {
85             total = (total + d[j]) % MOD_NUM;
86             ans = (ans + total * j) % MOD_NUM;
87         }
88         printf("%d\n", ans);
89 //        putchar('\n');
90     }
91     return 0;
92 }

最後,在做此題的過程中我還突然發現輸入外掛沒用了,加上輸入外掛後的執行時間比直接用scanf更長。也許是因為現在的oj用上了最新的編譯器吧。