1. 程式人生 > 實用技巧 >2020牛客暑期多校訓練營(第一場)A B-Suffix Array

2020牛客暑期多校訓練營(第一場)A B-Suffix Array

首先是題解的做法

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<climits>
#include<stack>
#include<vector>
#include<queue>
#include<set>
#include<bitset>
#include<map>
//#include<regex>
#include<cstdio>
#define up(i,a,b)  for(int i=a;i<b;i++)
#define dw(i,a,b)  for(int i=a;i>b;i--)
#define upd(i,a,b) for(int i=a;i<=b;i++)
#define dwd(i,a,b) for(int i=a;i>=b;i--)
//#define local
typedef long long ll;
typedef unsigned long long ull;
const double esp = 1e-6;
const double pi = acos(-1.0);
const int INF = 0x3f3f3f3f;
const int inf = 1e9;
using namespace std;
ll read()
{
    char ch = getchar(); ll x = 0, f = 1;
    while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); }
    while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
    return x * f;
}
typedef pair<int, int> pir;
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lrt root<<1
#define rrt root<<1|1
const int N = 1e5 + 10;
char ss[2 * N];
int n;
int a[2 * N];
int rk[2 * N], sa[2 * N], t[2 * N], t2[2 * N], c[2 * N], height[2 * N];
int nxt[N * 2][3];
void da(int *s,int sz)
{
    int *x = t, *y = t2;
    int m = sz;
    up(i, 0, m)c[i] = 0;
    up(i, 0, n)c[x[i] = s[i]]++;
    up(i, 1, m)c[i] += c[i - 1];
    dwd(i, n - 1, 0)sa[--c[x[i]]] = i;
    for (int k = 1; k <= n; k <<= 1)
    {
        int p = 0;
        up(i, n - k, n)y[p++] = i;
        up(i, 0, n)if (sa[i] >= k)y[p++] = sa[i] - k;
        up(i, 0, m)c[i] = 0;
        up(i, 0, n)c[x[y[i]]]++;
        up(i, 1, m)c[i] += c[i - 1];
        dwd(i, n - 1, 0)sa[--c[x[y[i]]]] = y[i];
        swap(x, y);
        x[sa[0]] = 0;
        p = 1;
        up(i, 1, n)
            x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
        if (p >= n)break;
        m = p;
        //cout << ")" << endl;
    }
}
void test(int *s)
{
    cout << "suffix:" << endl;
    up(i, 0, n)
    {
        //cout << "sa" << sa[i] << endl;
        up(j, sa[i], n)cout << s[j] << " ";
        cout << endl;
        cout <<"h: "<< height[i] << endl;
    }
}
int main()
{
    while (~scanf("%d\n", &n))
    {
        scanf("%s", ss + 1);
        upd(i, 0, n + 1)up(j, 0, 2)nxt[i][j] = 0;
        dwd(i, n, 1)
        {
            up(j, 0, 2)nxt[i][j] = nxt[i + 1][j];
            a[i - 1] = nxt[i][ss[i] - 'a'] ? n - nxt[i][ss[i] - 'a'] + i : 1;
            nxt[i][ss[i] - 'a'] = i;
        }
        a[n] = 0;
        n++;
        da(a, n + 5);
        //test(a);
        upd(i, 1, n-1)
        {
            printf("%d ", sa[i] + 1);
        }
        printf("\n");
    }
    return 0;
}

再說說正常人能想到的做法。
首先觀察,其實這個字串本質上是01串,替換a為0,替換b為1。
可以發現,字串形如11110000111000111000交替出現。
在發現,當算字尾的B函式的時候,字尾的值要麼只會從原來的值突變為0,要麼不變。
我們考慮字串111011101。考慮兩個字首,111011101,11101,發現他的B函式分別為012021242,01202,可以發現,他的B函式的字首相同,0120,所以這個時候,需要比較111011101和11101中,第一個字串字串11101和第二個字串的字尾1的大小。我們可以大膽假設,該字串通過B函式排序,僅僅之和當前(例如1)數字,與後面不和他一樣的數字的距離有關。
1.當距離不等時
有形如111011101,1101,B函式前兩位都是01,第三位開始不同,一個是1,一個是0,發現距離越近,B函式越小。
2.當距離相等時
有111011101,11101,此時後0後面的字尾有關。於是我們先通過後綴數字進行排序,然後通過求sa函式得到。
可以簡單證明正確性:
當有連續個1(或者0)時,函式值為01111111....,知道出現和他不同的數字,這個時候函式值是01111...1110,我們可以發現01兩個數字都已經出現過了,故後面的數字一定不是0,且,當進行字尾排序的時候,他在當前字尾中的函式值,就是B函式的值。比如原陣列1110111,我們觀察字尾10111,後3個1求得的函式值,和B函式值相等(因為在他之前01都已經出現過了,不會再出現因為他之前的0或者1被去掉,從而影響當前位置成為0,即改變函式值)。
所以第一步,我們求整個陣列的B函式,進行字尾排序。然後我們重新計算一個值,即當前位置的下一個,和自己不一樣的位置在哪裡。(這裡如果使用計算長度的話,會有邊界問題)然後我們需要考慮,末尾的情況,因為末尾假設是000的話,沒有下一個1。這個時候設定一個虛擬的節點在n+1位置,令他的rk值即排名為-1。保證後面為空的時候,比其他任何字尾要小。

#include<iostream>
#include<cstring>
#include<algorithm>
#include<cmath>
#include<cstdlib>
#include<climits>
#include<stack>
#include<vector>
#include<queue>
#include<set>
#include<bitset>
#include<map>
//#include<regex>
#include<cstdio>
#define up(i,a,b)  for(int i=a;i<b;i++)
#define dw(i,a,b)  for(int i=a;i>b;i--)
#define upd(i,a,b) for(int i=a;i<=b;i++)
#define dwd(i,a,b) for(int i=a;i>=b;i--)
//#define local
typedef long long ll;
typedef unsigned long long ull;
const double esp = 1e-6;
const double pi = acos(-1.0);
const int INF = 0x3f3f3f3f;
const int inf = 1e9;
using namespace std;
ll read()
{
	char ch = getchar(); ll x = 0, f = 1;
	while (ch<'0' || ch>'9') { if (ch == '-')f = -1; ch = getchar(); }
	while (ch >= '0' && ch <= '9') { x = x * 10 + ch - '0'; ch = getchar(); }
	return x * f;
}
typedef pair<int, int> pir;
#define lson l,mid,root<<1
#define rson mid+1,r,root<<1|1
#define lrt root<<1
#define rrt root<<1|1
const int N = 1e5 + 10;
char ss[2 * N];
int n;
int a[2 * N];
int rk[2 * N], sa[2 * N], t[2 * N], t2[2 * N], c[2 * N], height[2 * N];
int id[2 * N], dis[2 * N];
void da(int *s, int sz)
{
	int *x = t, *y = t2;
	int m = sz;
	up(i, 0, m)c[i] = 0;
	up(i, 0, n)c[x[i] = s[i]]++;
	up(i, 1, m)c[i] += c[i - 1];
	dwd(i, n - 1, 0)sa[--c[x[i]]] = i;
	for (int k = 1; k <= n; k <<= 1)
	{
		int p = 0;
		up(i, n - k, n)y[p++] = i;
		up(i, 0, n)if (sa[i] >= k)y[p++] = sa[i] - k;
		up(i, 0, m)c[i] = 0;
		up(i, 0, n)c[x[y[i]]]++;
		up(i, 1, m)c[i] += c[i - 1];
		dwd(i, n - 1, 0)sa[--c[x[y[i]]]] = y[i];
		swap(x, y);
		x[sa[0]] = 0;
		p = 1;
		up(i, 1, n)
			x[sa[i]] = y[sa[i - 1]] == y[sa[i]] && y[sa[i - 1] + k] == y[sa[i] + k] ? p - 1 : p++;
		if (p >= n)break;
		m = p;
		//cout << ")" << endl;
	}
}
void getheight()
{
	up(i, 0, n)rk[sa[i]] = i;
	int j = 0;
	int k = 0;
	up(i, 0, n)
	{
		if (k)k--;
		j = sa[rk[i] - 1];
		while (a[i + k] == a[j + k])k++;
		height[rk[i]] = k;
	}
}
void test(int *s)
{
	cout << "suffix:" << endl;
	up(i, 0, n)
	{
		cout << "sa" << sa[i] << endl;
		up(j, sa[i], n)cout << s[j] << " ";
		cout << endl;
		cout << "h: " << height[i] << endl;
	}
}
bool cmp(int a, int b)
{
	if (dis[a]-a == dis[b]-b)
	{
		return rk[dis[a]] < rk[dis[b]];
	}
	else return	dis[a] - a < dis[b] - b;
}
int main()
{
	while (~scanf("%d\n", &n))
	{
		scanf("%s", ss + 1);
		int posa = -1, posb = -1;
		upd(i, 1, n)
		{
			if (ss[i] == 'a')
			{
				if (posa == -1)a[i-1] = 1;
				else a[i-1] = i - posa + 1;
				posa = i;
			}
			else {
				if (posb == -1)a[i-1] = 1;
				else a[i-1] = i - posb + 1;
				posb = i;
			}
		}
		a[n] = 0;
		n += 1;
		da(a, n + 5);
		getheight();
		//test(a);
		n -= 1;
		upd(i, 1, n)id[i] = i;
		posa = -1; posb = -1;
		dwd(i, n,1)
		{
			if(ss[i]=='a')
			{
				if (posb == -1)dis[i] = n + 1;
				else dis[i] = posb;
				posa = i;
			}
			else {
				if (posa == -1)dis[i] = n + 1;
				else dis[i] = posa;
				posb = i;
			}
		}
		rk[n + 1] = -1;
		sort(id + 1, id + 1 + n,cmp);
		upd(i, 1, n)
			printf("%d ", id[i]);
		printf("\n");
	}
	return 0;
}