loj3326.「SNOI2020」字串(sa + 並查集)
阿新 • • 發佈:2020-12-24
loj3326.「SNOI2020」字串
\(Description\)
給定兩個長度為 \(n\) 的小寫字串 \(a, \ b\),求出他們所有長為 \(k\) 的子串,分別組成集合 \(A, \ B\),每次可以修改 \(A\) 中一個元素的字尾,費用為字尾的長度,求將 \(A\) 修改成 \(B\) 的最小費用之和。
\(Data \ Constraint\)
\(1 \leq k, \ n \leq 1.5 \times 10^5\)。
考點
\(sa\),並查集,廣義 \(sam\)。
\(Solution\)
本題有多種解法,本人寫的是 \(sa\) + 並查集做法(比較慢唔)。
答案可以轉化為 \(k (n - k + 1) - \sum{\text{匹配元素的}lcp}\),
有一個比較顯然的貪心是每次我們從 \(A, \ B\) 中選出一對 \(lcp\) 最大的元素配對。
證明也很簡單:考慮我們當前選了一對 \(lcp\) 最大的元素,顯然我們無法找到另外一對元素使得這兩對元素交叉匹配的 \(lcp\) 之和更大。
\(Method1\)
將 \(a, \ b\) 拼接起來建 \(sa\),從大到小列舉 \(height\),並查集維護塊及塊中未配對字尾的個數,每次合併兩個塊,若兩塊中未配對的字尾來自不同字串則計算對答案的貢獻,然後合併為配對的字尾。
\(Method2\)
直接建出廣義 \(sam\),簡單樹上 \(dp\) 計算即可。
\(Code(sa)\)
#include <cstdio> #include <cstring> #include <algorithm> using namespace std; #define N 600000 #define fo(i, x, y) for(int i = x; i <= y; i ++) #define fd(i, x, y) for(int i = x; i >= y; i --) #define Mec(a, b) memcpy(a, b, sizeof b) int sa[N + 1], rk[N + 1], oldrk[N + 1], buc[N + 1], px[N + 1], id[N + 1], ht[N + 1], c[N + 1]; char a[N + 1], b[N + 1]; #define ll long long int fa[N + 1]; ll sz[N + 1]; struct Arr { int x, y; } d[N + 1]; int n, m, m1 = 26; void Sort() { fill(buc, buc + N, 0); fo(i, 1, n) ++ buc[ px[i] = rk[id[i]] ]; fo(i, 1, m1) buc[i] += buc[i - 1]; fd(i, n, 1) sa[ buc[px[i]] -- ] = id[i]; } bool Cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]; } void Get_sa() { fo(i, 1, n) c[i] = a[i] - 'a' + 1; fo(i, 1, n) c[i + n] = b[i] - 'a' + 1; n = (n << 1); fo(i, 1, n) rk[ id[i] = i ] = c[i]; Sort(); for (int w = 1, p = 0; w <= n; w <<= 1) { p = 0; fo(i, n - w + 1, n) id[ ++ p ] = i; fo(i, 1, n) if (sa[i] > w) id[ ++ p ] = sa[i] - w; Sort(); Mec(oldrk, rk), p = 0; fo(i, 1, n) rk[sa[i]] = Cmp(sa[i], sa[i - 1], w) ? p : ++ p; m1 = p; } int k = 0; fo(i, 1, n) { if (k) k --; while (c[i + k] == c[sa[rk[i] - 1] + k]) ++ k; ht[rk[i]] = k; } } bool cmp(Arr a, Arr b) { return a.x < b.x; } int Getf(int u) { return u == fa[u] ? u : fa[u] = Getf(fa[u]); } bool Pd(int x) { return (x <= (n >> 1) && (n >> 1) - x + 1 >= m) || (x > (n >> 1) && n - x + 1 >= m); } int Abs(int x) { return x < 0 ? -x : x; } int main() { scanf("%d %d\n", &n, &m); scanf("%s\n%s\n", a + 1, b + 1); Get_sa(); fo(i, 1, n) fa[i] = i, sz[i] = Pd(sa[i]) ? (sa[i] <= (n >> 1) ? 1 : -1) : 0; int tot = 0; fo(i, 2, n) d[ ++ tot ] = (Arr) { ht[i], i }; sort(d + 1, d + 1 + tot, cmp); ll ans = 0; fd(i, (n >> 1), 1) { while (d[tot].x == i) { int u = Getf(d[tot].y), v = Getf(d[tot].y - 1); if (sz[u] * sz[v] < 0) ans += min(Abs(sz[u]), Abs(sz[v])) * min(i, m); fa[v] = u; sz[u] += sz[v]; -- tot; } } printf("%lld\n", 1ll * m * ((n >> 1) - m + 1) - ans); return 0; }