1. 程式人生 > >[Luogu P3181] [BZOJ 4566] [HAOI2016]找相同字元

[Luogu P3181] [BZOJ 4566] [HAOI2016]找相同字元

洛谷傳送門

題目描述

給定兩個字串,求出在兩個字串中各取出一個子串使得這兩個子串相同的方案數。兩個方案不同當且僅當這兩個子串中有一個位置不同。

輸入輸出格式

輸入格式:

兩行,兩個字串s1s_1s2s_2,長度分別為n1n_1n2n_21n1,n22000001 \le n_1, n_2\le 200000,字串中只有小寫字母

輸出格式:

輸出一個整數表示答案

輸入輸出樣例

輸入樣例#1:

aabb
bbaa

輸出樣例#1:

10

解題分析

這道題大概有兩種做法。

第一種是在後綴自動機中插入第一個串後插入一個無法匹配的字元, 再插入第二個串, 並同時記錄兩次插入每個狀態的r

ightright集合大小。 插入無法匹配的字元的原因是不能讓中間連起來形成合法的子串。

程式碼如下:

#include <cstring>
#include <cstdlib>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <cstdio>
#define R register
#define IN inline
#define W while
#define MX 2000500
char dat[MX];
int to[MX][27]
, par[MX], len[MX], siz[2][MX], buc[MX], ind[MX]; int l, cnt, cur, last; long long ans; namespace SAM { IN void insert(R int id, R int typ) { R int now = last, tar; cur = ++cnt; len[cur] = len[last] + 1; siz[typ][cur] = 1; for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
if(now < 0) return par[last = cur] = 0, void(); tar = to[now][id]; if(len[tar] == len[now] + 1) return par[last = cur] = tar, void(); int nw = ++cnt; len[nw] = len[now] + 1; par[nw] = par[tar], par[tar] = par[last = cur] = nw; std::memcpy(to[nw], to[tar], sizeof(to[nw])); for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw; } IN void calc() { for (R int i = 1; i <= cnt; ++i) buc[len[i]]++; for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1]; for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i; for (R int i = cnt; i; --i) if(~par[i]) ans += 1ll * (len[ind[i]] - len[par[ind[i]]]) * siz[0][ind[i]] * siz[1][ind[i]], siz[0][par[ind[i]]] += siz[0][ind[i]], siz[1][par[ind[i]]] += siz[1][ind[i]]; } } int main(void) { par[0] = -1; scanf("%s", dat + 1); l = std::strlen(dat + 1); for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 0); SAM::insert(26, 0); scanf("%s", dat + 1); l = std::strlen(dat + 1); for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 1); SAM::calc(); printf("%lld\n", ans); }

另一種是用廣義字尾自動機。 我們在後綴自動機中插入第一個串後將lastlast設為初始值, 再插入第二個串。 因為現在每個點的意義是兩個串的公共子串, 所以統計時注意lenlen要滿足第二個串的要求, 及時分裂節點。

程式碼如下:

#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <cstdlib>
#define R register
#define IN inline
#define W while
#define MX 1000500
int par[MX], to[MX][26], len[MX], siz[2][MX], buc[MX], ind[MX];
int cnt, l, last, cur;
char dat[MX];
long long ans;
namespace SAM
{
	IN void insert(R int ln, R int id, R int typ)
	{
		R int now = last, tar, sig = 0;
		if((!to[now][id]) || (len[to[now][id]] != len[now] + 1)) cur = ++cnt;
		else cur = to[now][id], sig = 1;
		++siz[typ][cur], last = cur;
		if(sig) return;
		len[cur] = ln;
		for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
		if (now < 0) return par[cur] = 0, void();
		tar = to[now][id];
		if(len[tar] == len[now] + 1) return par[cur] = tar, void();
		int nw = ++cnt; len[nw] = len[now] + 1;
		par[nw] = par[tar], par[tar] = par[cur] = nw;
		std::memcpy(to[nw], to[tar], sizeof(to[nw]));
		for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
	}
}
int main(void)
{
	par[0] = -1;
	scanf("%s", dat + 1); l = std::strlen(dat + 1);
	for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 0);
	scanf("%s", dat + 1); l = std::strlen(dat + 1); last = 0;
	for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 1);
	for (R int i = 1; i <= cnt; ++i) buc[len[i]]++;
	for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
	for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i;
	for (R int i = cnt; i; --i)
	ans += 1ll * siz[0][ind[i]] * siz[1][ind[i]] * (len[ind[i]] - len[par[ind[i]]]),
	siz[1][par[ind[i]]] += siz[1][ind[i]], siz[0][par[ind[i]]] += siz[0][ind[i]];
	printf("%lld", ans);
}