1. 程式人生 > 其它 >ICPC2021 瀋陽站 M String Problem

ICPC2021 瀋陽站 M String Problem

牛客傳送門


KMP的做法暫時沒看懂,這裡提供兩種SAM的做法。
感謝櫻花豬開喵喵車創大白熊新手上路兩隊的程式碼提供的思路。


第一種做法稍麻煩一些:

對於每一個字首,字典序最大的子串一定是該字首的一個字尾,而比較這些字尾的方法就是選擇這些字尾中,最靠前的不同的字元。如果將原串反過來,就可以用SAM維護了。

將反串建成SAM,然後對於字尾連結樹上每一個節點\(u\)的出邊\(v_i\),按\(endpos[v] - len[u]\)在原串中的字元排序,這樣就能優先訪問字典序更大的子串了。

現在對於每個字首都要求對應的最大字尾。可以倒著做:先將所有節點以dfs序為關鍵字扔到一個大根堆中,因為dfs序大的節點代表的子串一定大,那麼如果當前堆頂代表的子串在列舉的當前字首的範圍內,那麼這個子串就是答案,否則將堆頂彈出,再取堆中最大的元素。

這樣時間複雜度是\(O(n \log n)\),需要稍微加一些常數優化才能通過。

#include<bits/stdc++.h>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const int maxs = 27;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n, ans[maxn];
char s[maxn];
struct Sam
{
	int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1], endp[maxn << 1], cnt, las;
	In void init() {link[cnt = las = 0] = -1; Mem(tra[0], 0);}
	In void insert(int c, int id)
	{
		int now = ++cnt, p = las; Mem(tra[now], 0);
		len[now] = len[p] + 1, endp[now] = id;
		while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				int clo = ++cnt;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				len[clo] = len[p] + 1, endp[clo] = endp[q];
				link[clo] = link[q], link[q] = link[now] = clo;
				while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
			}
		}
		las = now;
	}
	#define pr pair<int, int>
	#define mp make_pair
	#define F first
	#define S second
	int buc[maxn << 1], pos[maxn << 1];
	vector<pr> V[maxn << 1];
	int du[maxn << 1], dfn[maxn << 1], dcnt;
	In void dfs(int now)
	{
		dfn[now] = ++dcnt;
		for(auto x : V[now]) dfs(x.S);
	}
	In void buildGraph()
	{
		for(int i = 1; i <= cnt; ++i) buc[len[i]]++;
		for(int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
		for(int i = 1; i <= cnt; ++i) pos[buc[len[i]]--] = i;
		endp[0] = INF;
		for(int i = cnt; i; --i)
		{
			int now = pos[i], fa = link[now];
			du[fa]++;
			endp[fa] = min(endp[fa], endp[now]);
			V[fa].push_back(mp(s[endp[now] + len[fa]], now));
		}
		for(int i = 0; i <= cnt; ++i) sort(V[i].begin(), V[i].end());
		dcnt = 0, dfs(0);
	}
	In void solve()
	{
		priority_queue<pr> q;
		for(int i = 1; i <= cnt; ++i) if(!du[i]) q.push(mp(dfn[i], i));
		for(int i = n, now = 0; i; --i)
		{
			while(!ans[i])
			{
				if(!now) now = q.top().S;		//減少堆操作來優化常數 
				if(endp[now] + len[link[now]] > i)
				{
					q.pop();
					if(now && !--du[link[now]]) q.push(mp(dfn[link[now]], link[now]));
					now = 0;
				}
				else ans[i] = endp[now];
			}
		}
	}
}S;

int main()
{
	scanf("%s",s + 1);
	n = strlen(s + 1); S.init();
	for(int i = n; i; --i) S.insert(s[i] - 'a', i);
	S.buildGraph(), 
	S.solve();
	for(int i = 1; i <= n; ++i) write(ans[i]), space, write(i), enter;
	return 0;
}

第二種做法程式碼量相對來說短了不少,我認為是對暴力的一種優化。

首先這題一種\(O(n^2)\)的暴力做法是取出所有子串,並按字典序總大到小排序,記一個子串是\(S_{l \sim r}\),那麼\(ans[r]\)的答案就是第一個出現的\(S_{l \sim r}\)

用SAM優化這個方法:用正串建完SAM後,貪心的在SAM上跑字典序最大的子串,那麼第一個走到該節點的子串一定是最大的。又因為在同一個節點的子串結束位置相同,而且經過這個節點到達別的節點形成的子串字首相同,所以後來經過這個節點形成的子串一定比第一次經過的要小,那麼走過的節點就不用再走了。

時間複雜度就是\(O(27n)\)

.

#include<bits/stdc++.h>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const int maxs = 27;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n, ans[maxn];
char s[maxn];
struct Sam
{
	int tra[maxn << 1][maxs], link[maxn << 1], len[maxn << 1], endp[maxn << 1], cnt, las;
	In void init() {link[cnt = las = 0] = -1; Mem(tra[0], 0);}
	In void insert(int c, int id)
	{
		int now = ++cnt, p = las; Mem(tra[now], 0);
		len[now] = len[p] + 1, endp[now] = id;
		while(~p && !tra[p][c]) tra[p][c] = now, p = link[p];
		if(p == -1) link[now] = 0;
		else
		{
			int q = tra[p][c];
			if(len[q] == len[p] + 1) link[now] = q;
			else
			{
				int clo = ++cnt;
				memcpy(tra[clo], tra[q], sizeof(tra[q]));
				len[clo] = len[p] + 1, endp[clo] = endp[q];
				link[clo] = link[q], link[q] = link[now] = clo;
				while(~p && tra[p][c] == q) tra[p][c] = clo, p = link[p];
			}
		}
		las = now;
	}
	bool vis[maxn << 1];
	In void dfs(int now, int l)				//l:最大子串開始位置 
	{
		vis[now] = 1;
		for(int i = 25; i >= 0; --i)		//在SAM貪心地走最大的 
			if(tra[now][i] && !vis[tra[now][i]]) dfs(tra[now][i], l + 1);
		if(!ans[endp[now]]) ans[endp[now]] = endp[now] - l + 1;
	}
}S;

int main()
{
	scanf("%s",s + 1);
	n = strlen(s + 1); S.init();
	for(int i = 1; i <= n; ++i) S.insert(s[i] - 'a', i);
	S.dfs(0, 0);
	for(int i = 1; i <= n; ++i) write(ans[i]), space, write(i), enter;
	return 0;
}

還有一個就是kmp的做法,我雖然沒看懂,不過也發一下程式碼吧。

#include<bits/stdc++.h>
using namespace std;
#define enter puts("") 
#define space putchar(' ')
#define Mem(a, x) memset(a, x, sizeof(a))
#define In inline
#define forE(i, x, y) for(int i = head[x], y; ~i && (y = e[i].to); i = e[i].nxt)
typedef long long ll;
typedef double db;
const int INF = 0x3f3f3f3f;
const db eps = 1e-8;
const int maxn = 1e6 + 5;
const int maxs = 27;
In ll read()
{
	ll ans = 0;
	char ch = getchar(), las = ' ';
	while(!isdigit(ch)) las = ch, ch = getchar();
	while(isdigit(ch)) ans = (ans << 1) + (ans << 3) + ch - '0', ch = getchar();
	if(las == '-') ans = -ans;
	return ans;
}
In void write(ll x)
{
	if(x < 0) x = -x, putchar('-');
	if(x >= 10) write(x / 10);
	putchar(x % 10 + '0');
}

int n;
char s[maxn];

vector<int> f, g;

int main()				//好短 
{
	scanf("%s",s + 1);
	n = strlen(s + 1);
	for(int i = 1; i <= n; ++i)
	{
		g.clear(), g.push_back(i);
		for(auto x : f)
		{
			while(!g.empty() && s[x + i - g.back()] > s[i]) g.pop_back();
			if(g.empty() || s[x + i - g.back()] == s[i]) g.push_back(x);
		}
		f.clear();
		for(auto x : g)
		{
			while(!f.empty() && (i - f.back() + 1) * 2 > i - x + 1) f.pop_back();
			f.push_back(x);
		}
		write(f.back()), space, write(i), enter;
	}
	return 0;
}