2020牛客暑期多校訓練營(第一場)A B-Suffix Array
首先是題解的做法
#include<iostream> #include<cstring> #include<algorithm> #include<cmath> #include<cstdlib> #include<climits> #include<stack> #include<vector> #include<queue> #include<set> #include<bitset> #include<map> //#include<regex> #include<cstdio> #define up(i,a,b) for(int i=a;i<b;i++) #define dw(i,a,b) for(int i=a;i>b;i--) #define upd(i,a,b) for(int i=a;i<=b;i++) #define dwd(i,a,b) for(int i=a;i>=b;i--) //#define local typedef long long ll; typedef unsigned long long ull; const double esp = 1e-6; const double pi = acos(-1.0); const int INF = 0x3f3f3f3f; const int inf = 1e9; using namespace std; ll read() { char ch = getchar(); ll x = 0, f = 1; while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } return x * f; } typedef pair<int, int> pir; #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lrt root<<1 #define rrt root<<1|1 const int N = 1e5 + 10; char ss[2 * N]; int n; int a[2 * N]; int rk[2 * N], sa[2 * N], t[2 * N], t2[2 * N], c[2 * N], height[2 * N]; int nxt[N * 2][3]; void da(int *s,int sz) { int *x = t, *y = t2; int m = sz; up(i, 0, m)c[i] = 0; up(i, 0, n)c[x[i] = s[i]]++; up(i, 1, m)c[i] += c[i - 1]; dwd(i, n - 1, 0)sa[--c[x[i]]] = i; for (int k = 1; k <= n; k <<= 1) { int p = 0; up(i, n - k, n)y[p++] = i; up(i, 0, n)if (sa[i] >= k)y[p++] = sa[i] - k; up(i, 0, m)c[i] = 0; up(i, 0, n)c[x[y[i]]]++; up(i, 1, m)c[i] += c[i - 1]; dwd(i, n - 1, 0)sa[--c[x[y[i]]]] = y[i]; swap(x, y); x[sa[0]] = 0; p = 1; up(i, 1, n) x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++; if (p >= n)break; m = p; //cout << ")" << endl; } } void test(int *s) { cout << "suffix:" << endl; up(i, 0, n) { //cout << "sa" << sa[i] << endl; up(j, sa[i], n)cout << s[j] << " "; cout << endl; cout <<"h: "<< height[i] << endl; } } int main() { while (~scanf("%d\n", &n)) { scanf("%s", ss + 1); upd(i, 0, n + 1)up(j, 0, 2)nxt[i][j] = 0; dwd(i, n, 1) { up(j, 0, 2)nxt[i][j] = nxt[i + 1][j]; a[i - 1] = nxt[i][ss[i] - 'a'] ? n - nxt[i][ss[i] - 'a'] + i : 1; nxt[i][ss[i] - 'a'] = i; } a[n] = 0; n++; da(a, n + 5); //test(a); upd(i, 1, n-1) { printf("%d ", sa[i] + 1); } printf("\n"); } return 0; }
再說說正常人能想到的做法。
首先觀察,其實這個字串本質上是01串,替換a為0,替換b為1。
可以發現,字串形如11110000111000111000交替出現。
在發現,當算字尾的B函式的時候,字尾的值要麼只會從原來的值突變為0,要麼不變。
我們考慮字串111011101。考慮兩個字首,111011101,11101,發現他的B函式分別為012021242,01202,可以發現,他的B函式的字首相同,0120,所以這個時候,需要比較111011101和11101中,第一個字串字串11101和第二個字串的字尾1的大小。我們可以大膽假設,該字串通過B函式排序,僅僅之和當前(例如1)數字,與後面不和他一樣的數字的距離有關。
1.當距離不等時
有形如111011101,1101,B函式前兩位都是01,第三位開始不同,一個是1,一個是0,發現距離越近,B函式越小。
2.當距離相等時
有111011101,11101,此時後0後面的字尾有關。於是我們先通過後綴數字進行排序,然後通過求sa函式得到。
可以簡單證明正確性:
當有連續個1(或者0)時,函式值為01111111....,知道出現和他不同的數字,這個時候函式值是01111...1110,我們可以發現01兩個數字都已經出現過了,故後面的數字一定不是0,且,當進行字尾排序的時候,他在當前字尾中的函式值,就是B函式的值。比如原陣列1110111,我們觀察字尾10111,後3個1求得的函式值,和B函式值相等(因為在他之前01都已經出現過了,不會再出現因為他之前的0或者1被去掉,從而影響當前位置成為0,即改變函式值)。
所以第一步,我們求整個陣列的B函式,進行字尾排序。然後我們重新計算一個值,即當前位置的下一個,和自己不一樣的位置在哪裡。(這裡如果使用計算長度的話,會有邊界問題)然後我們需要考慮,末尾的情況,因為末尾假設是000的話,沒有下一個1。這個時候設定一個虛擬的節點在n+1位置,令他的rk值即排名為-1。保證後面為空的時候,比其他任何字尾要小。
#include<iostream> #include<cstring> #include<algorithm> #include<cmath> #include<cstdlib> #include<climits> #include<stack> #include<vector> #include<queue> #include<set> #include<bitset> #include<map> //#include<regex> #include<cstdio> #define up(i,a,b) for(int i=a;i<b;i++) #define dw(i,a,b) for(int i=a;i>b;i--) #define upd(i,a,b) for(int i=a;i<=b;i++) #define dwd(i,a,b) for(int i=a;i>=b;i--) //#define local typedef long long ll; typedef unsigned long long ull; const double esp = 1e-6; const double pi = acos(-1.0); const int INF = 0x3f3f3f3f; const int inf = 1e9; using namespace std; ll read() { char ch = getchar(); ll x = 0, f = 1; while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); } while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); } return x * f; } typedef pair<int, int> pir; #define lson l,mid,root<<1 #define rson mid+1,r,root<<1|1 #define lrt root<<1 #define rrt root<<1|1 const int N = 1e5 + 10; char ss[2 * N]; int n; int a[2 * N]; int rk[2 * N], sa[2 * N], t[2 * N], t2[2 * N], c[2 * N], height[2 * N]; int id[2 * N], dis[2 * N]; void da(int *s, int sz) { int *x = t, *y = t2; int m = sz; up(i, 0, m)c[i] = 0; up(i, 0, n)c[x[i] = s[i]]++; up(i, 1, m)c[i] += c[i - 1]; dwd(i, n - 1, 0)sa[--c[x[i]]] = i; for (int k = 1; k <= n; k <<= 1) { int p = 0; up(i, n - k, n)y[p++] = i; up(i, 0, n)if (sa[i] >= k)y[p++] = sa[i] - k; up(i, 0, m)c[i] = 0; up(i, 0, n)c[x[y[i]]]++; up(i, 1, m)c[i] += c[i - 1]; dwd(i, n - 1, 0)sa[--c[x[y[i]]]] = y[i]; swap(x, y); x[sa[0]] = 0; p = 1; up(i, 1, n) x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++; if (p >= n)break; m = p; //cout << ")" << endl; } } void getheight() { up(i, 0, n)rk[sa[i]] = i; int j = 0; int k = 0; up(i, 0, n) { if (k)k--; j = sa[rk[i] - 1]; while (a[i + k] == a[j + k])k++; height[rk[i]] = k; } } void test(int *s) { cout << "suffix:" << endl; up(i, 0, n) { cout << "sa" << sa[i] << endl; up(j, sa[i], n)cout << s[j] << " "; cout << endl; cout << "h: " << height[i] << endl; } } bool cmp(int a, int b) { if (dis[a]-a == dis[b]-b) { return rk[dis[a]] < rk[dis[b]]; } else return dis[a] - a < dis[b] - b; } int main() { while (~scanf("%d\n", &n)) { scanf("%s", ss + 1); int posa = -1, posb = -1; upd(i, 1, n) { if (ss[i] == 'a') { if (posa == -1)a[i-1] = 1; else a[i-1] = i - posa + 1; posa = i; } else { if (posb == -1)a[i-1] = 1; else a[i-1] = i - posb + 1; posb = i; } } a[n] = 0; n += 1; da(a, n + 5); getheight(); //test(a); n -= 1; upd(i, 1, n)id[i] = i; posa = -1; posb = -1; dwd(i, n,1) { if(ss[i]=='a') { if (posb == -1)dis[i] = n + 1; else dis[i] = posb; posa = i; } else { if (posa == -1)dis[i] = n + 1; else dis[i] = posa; posb = i; } } rk[n + 1] = -1; sort(id + 1, id + 1 + n,cmp); upd(i, 1, n) printf("%d ", id[i]); printf("\n"); } return 0; }