[Luogu P3181] [BZOJ 4566] [HAOI2016]找相同字元
阿新 • • 發佈:2018-12-10
洛谷傳送門
題目描述
給定兩個字串,求出在兩個字串中各取出一個子串使得這兩個子串相同的方案數。兩個方案不同當且僅當這兩個子串中有一個位置不同。
輸入輸出格式
輸入格式:
兩行,兩個字串,,長度分別為,。,字串中只有小寫字母
輸出格式:
輸出一個整數表示答案
輸入輸出樣例
輸入樣例#1:
aabb
bbaa
輸出樣例#1:
10
解題分析
這道題大概有兩種做法。
第一種是在後綴自動機中插入第一個串後插入一個無法匹配的字元, 再插入第二個串, 並同時記錄兩次插入每個狀態的集合大小。 插入無法匹配的字元的原因是不能讓中間連起來形成合法的子串。
程式碼如下:
#include <cstring>
#include <cstdlib>
#include <cmath>
#include <cctype>
#include <algorithm>
#include <cstdio>
#define R register
#define IN inline
#define W while
#define MX 2000500
char dat[MX];
int to[MX][27] , par[MX], len[MX], siz[2][MX], buc[MX], ind[MX];
int l, cnt, cur, last;
long long ans;
namespace SAM
{
IN void insert(R int id, R int typ)
{
R int now = last, tar;
cur = ++cnt; len[cur] = len[last] + 1; siz[typ][cur] = 1;
for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
if(now < 0) return par[last = cur] = 0, void();
tar = to[now][id];
if(len[tar] == len[now] + 1) return par[last = cur] = tar, void();
int nw = ++cnt; len[nw] = len[now] + 1;
par[nw] = par[tar], par[tar] = par[last = cur] = nw;
std::memcpy(to[nw], to[tar], sizeof(to[nw]));
for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
}
IN void calc()
{
for (R int i = 1; i <= cnt; ++i) buc[len[i]]++;
for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i;
for (R int i = cnt; i; --i)
if(~par[i]) ans += 1ll * (len[ind[i]] - len[par[ind[i]]]) * siz[0][ind[i]] * siz[1][ind[i]],
siz[0][par[ind[i]]] += siz[0][ind[i]], siz[1][par[ind[i]]] += siz[1][ind[i]];
}
}
int main(void)
{
par[0] = -1;
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 0);
SAM::insert(26, 0);
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(dat[i] - 'a', 1);
SAM::calc(); printf("%lld\n", ans);
}
另一種是用廣義字尾自動機。 我們在後綴自動機中插入第一個串後將設為初始值, 再插入第二個串。 因為現在每個點的意義是兩個串的公共子串, 所以統計時注意要滿足第二個串的要求, 及時分裂節點。
程式碼如下:
#include <cstdio>
#include <cstring>
#include <cmath>
#include <cctype>
#include <cstdlib>
#define R register
#define IN inline
#define W while
#define MX 1000500
int par[MX], to[MX][26], len[MX], siz[2][MX], buc[MX], ind[MX];
int cnt, l, last, cur;
char dat[MX];
long long ans;
namespace SAM
{
IN void insert(R int ln, R int id, R int typ)
{
R int now = last, tar, sig = 0;
if((!to[now][id]) || (len[to[now][id]] != len[now] + 1)) cur = ++cnt;
else cur = to[now][id], sig = 1;
++siz[typ][cur], last = cur;
if(sig) return;
len[cur] = ln;
for (; (~now) && !to[now][id]; now = par[now]) to[now][id] = cur;
if (now < 0) return par[cur] = 0, void();
tar = to[now][id];
if(len[tar] == len[now] + 1) return par[cur] = tar, void();
int nw = ++cnt; len[nw] = len[now] + 1;
par[nw] = par[tar], par[tar] = par[cur] = nw;
std::memcpy(to[nw], to[tar], sizeof(to[nw]));
for (; (~now) && to[now][id] == tar; now = par[now]) to[now][id] = nw;
}
}
int main(void)
{
par[0] = -1;
scanf("%s", dat + 1); l = std::strlen(dat + 1);
for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 0);
scanf("%s", dat + 1); l = std::strlen(dat + 1); last = 0;
for (R int i = 1; i <= l; ++i) SAM::insert(i, dat[i] - 'a', 1);
for (R int i = 1; i <= cnt; ++i) buc[len[i]]++;
for (R int i = 1; i <= cnt; ++i) buc[i] += buc[i - 1];
for (R int i = 1; i <= cnt; ++i) ind[buc[len[i]]--] = i;
for (R int i = cnt; i; --i)
ans += 1ll * siz[0][ind[i]] * siz[1][ind[i]] * (len[ind[i]] - len[par[ind[i]]]),
siz[1][par[ind[i]]] += siz[1][ind[i]], siz[0][par[ind[i]]] += siz[0][ind[i]];
printf("%lld", ans);
}