1. 程式人生 > 實用技巧 >P4248 [AHOI2013]差異

P4248 [AHOI2013]差異

P4248 [AHOI2013]差異

題面

傳送門

題解

把式子拆了, 成了兩部分,

\(\sum len(T) - 2\sum_{1≤i<j≤n} lca(T_i, T_j)\)

第一部分就是長度 n * (n - 1) * (n + 1) >> 1

第二部分是所有字串lca長的和的兩倍, sam求一下就好了, (算lca, 你反轉一下, 不就是等價類的minlen嗎?)

注意對於\(T_i, T_j\)他們共同的endposs可能是新的複製點nq

但是按照rk順序, T_i or T_j 其中之一已經併到了 endposs nq中了, 直接算即可

我個憨憨, 直接在 endposs nq(\(T_i, T_j\)

) 中算的貢獻, 那麼這個endposs在轉移到下一個endposs fa[nq](\(T_i, T_j\))的時候貢獻又被算了一次

\(T_i, T_j\) 被算了多次, 且每次被計算的貢獻是 endposs的長度 答案肯定是錯的

要在$ T_i or T_j $其中一方進入 endposs, 一方沒進入 endposs 的時候算 \(T_i ,T_j\) 的貢獻才是真正的貢獻

struct SAM { //不管是不是多組資料都呼叫init
    static const int N = 5e5 + 5, M = 26, C = 'a';
    struct node { int fa, len, ne[M]; } tr[N << 1];
    int sz, las, len, c[N], rk[N << 1], cnt[N << 1];//(i~len)有cnt[i]個字母a[i]
    int sum[N << 1]; //排名為i的節點為頭包含的字串數量
    int ans[N << 1], f[N << 1];
    void init() {
        rep(i, 1, sz)
            tr[i].len = tr[i].fa = c[i] = 0, memset(tr[i].ne, 0, sizeof tr[i].ne);
        sz = las = 1;
    }
    void add(int ch) {
        int p = las, cur = las = ++sz;
        tr[cur].len = tr[p].len + 1; ++cnt[cur];
        for (; p && !tr[p].ne[ch]; p = tr[p].fa) tr[p].ne[ch] = cur;
        if (p == 0) { tr[cur].fa = 1; return; }
        int q = tr[p].ne[ch];
        if (tr[q].len == tr[p].len + 1) { tr[cur].fa = q; return; }
        int nq = ++sz; tr[nq] = tr[q]; tr[nq].len = tr[p].len + 1;
        for (; p && tr[p].ne[ch] == q; p = tr[p].fa) tr[p].ne[ch] = nq;
        tr[q].fa = tr[cur].fa = nq;
    }
    void build(char* s) {
        for (int& i = len; s[i]; ++i) add(s[i] - C);
    }
    void sort() {
        rep(i, 1, sz) c[i] = 0;
        rep(i, 1, sz) ++c[tr[i].len];
        rep(i, 1, len) c[i] += c[i - 1];
        rep(i, 1, sz) rk[c[tr[i].len]--] = i;
    }
    ll solve(char* s, int n) {
        ll ans = (n - 1ll) * (n + 1ll) * n >> 1;
        per(i, sz, 1) ans -= (ll)cnt[tr[rk[i]].fa] * cnt[rk[i]] * tr[tr[rk[i]].fa].len << 1, cnt[tr[rk[i]].fa] += cnt[rk[i]];
        return ans;
    }
} sam;

const int N = 5e5 + 5, inf = 0x3f3f3f3f;

int n, m, _, k;
char s[N];

int main() {
    sam.init(); cin >> s; m = strlen(s); reverse(s, s + m); sam.build(s);
    sam.sort(); cout << sam.solve(s, m);
    return 0;
}