1. 程式人生 > 實用技巧 >loj3326.「SNOI2020」字串(sa + 並查集)

loj3326.「SNOI2020」字串(sa + 並查集)

loj3326.「SNOI2020」字串

\(Description\)

給定兩個長度為 \(n\) 的小寫字串 \(a, \ b\),求出他們所有長為 \(k\) 的子串,分別組成集合 \(A, \ B\),每次可以修改 \(A\) 中一個元素的字尾,費用為字尾的長度,求將 \(A\) 修改成 \(B\) 的最小費用之和。

\(Data \ Constraint\)

\(1 \leq k, \ n \leq 1.5 \times 10^5\)

考點

\(sa\),並查集,廣義 \(sam\)

\(Solution\)

本題有多種解法,本人寫的是 \(sa\) + 並查集做法(比較慢唔)。

答案可以轉化為 \(k (n - k + 1) - \sum{\text{匹配元素的}lcp}\)

有一個比較顯然的貪心是每次我們從 \(A, \ B\) 中選出一對 \(lcp\) 最大的元素配對。

證明也很簡單:考慮我們當前選了一對 \(lcp\) 最大的元素,顯然我們無法找到另外一對元素使得這兩對元素交叉匹配的 \(lcp\) 之和更大。

\(Method1\)

\(a, \ b\) 拼接起來建 \(sa\),從大到小列舉 \(height\),並查集維護塊及塊中未配對字尾的個數,每次合併兩個塊,若兩塊中未配對的字尾來自不同字串則計算對答案的貢獻,然後合併為配對的字尾。

\(Method2\)

直接建出廣義 \(sam\),簡單樹上 \(dp\) 計算即可。

\(Code(sa)\)

#include <cstdio>
#include <cstring>
#include <algorithm>

using namespace std;

#define N 600000

#define fo(i, x, y) for(int i = x; i <= y; i ++)
#define fd(i, x, y) for(int i = x; i >= y; i --)
#define Mec(a, b) memcpy(a, b, sizeof b)

int sa[N + 1], rk[N + 1], oldrk[N + 1], buc[N + 1], px[N + 1], id[N + 1], ht[N + 1], c[N + 1];

char a[N + 1], b[N + 1];

#define ll long long

int fa[N + 1];

ll sz[N + 1];

struct Arr { int x, y; } d[N + 1];

int n, m, m1 = 26;

void Sort() {
    fill(buc, buc + N, 0);
    fo(i, 1, n) ++ buc[ px[i] = rk[id[i]] ];
    fo(i, 1, m1) buc[i] += buc[i - 1];
    fd(i, n, 1) sa[ buc[px[i]] -- ] = id[i];
}

bool Cmp(int x, int y, int w) { return oldrk[x] == oldrk[y] && oldrk[x + w] == oldrk[y + w]; }

void Get_sa() {
    fo(i, 1, n) c[i] = a[i] - 'a' + 1;
    fo(i, 1, n) c[i + n] = b[i] - 'a' + 1;
    n = (n << 1);
    fo(i, 1, n) rk[ id[i] = i ] = c[i];
    Sort();
    for (int w = 1, p = 0; w <= n; w <<= 1) {
        p = 0;
        fo(i, n - w + 1, n) id[ ++ p ] = i;
        fo(i, 1, n) if (sa[i] > w) id[ ++ p ] = sa[i] - w;
        Sort();
        Mec(oldrk, rk), p = 0;
        fo(i, 1, n)
            rk[sa[i]] = Cmp(sa[i], sa[i - 1], w) ? p : ++ p;
        m1 = p;
    }
    int k = 0;
    fo(i, 1, n) {
        if (k) k --;
        while (c[i + k] == c[sa[rk[i] - 1] + k]) ++ k;
        ht[rk[i]] = k;
    }
}

bool cmp(Arr a, Arr b) { return a.x < b.x; }

int Getf(int u) { return u == fa[u] ? u : fa[u] = Getf(fa[u]); }

bool Pd(int x) { return (x <= (n >> 1) && (n >> 1) - x + 1 >= m) || (x > (n >> 1) && n - x + 1 >= m); } 

int Abs(int x) { return x < 0 ? -x : x; }

int main() {
    scanf("%d %d\n", &n, &m);
    scanf("%s\n%s\n", a + 1, b + 1);

    Get_sa();

    fo(i, 1, n) fa[i] = i, sz[i] = Pd(sa[i]) ? (sa[i] <= (n >> 1) ? 1 : -1) : 0;
    int tot = 0;
    fo(i, 2, n)
        d[ ++ tot ] = (Arr) { ht[i], i };
    sort(d + 1, d + 1 + tot, cmp);
    ll ans = 0;
    fd(i, (n >> 1), 1) {
        while (d[tot].x == i) {
            int u = Getf(d[tot].y), v = Getf(d[tot].y - 1);
            if (sz[u] * sz[v] < 0)
                ans += min(Abs(sz[u]), Abs(sz[v])) * min(i, m);
            fa[v] = u;
            sz[u] += sz[v];
            -- tot;
        }
    }
    printf("%lld\n", 1ll * m * ((n >> 1) - m + 1) - ans);

    return 0;
}