1. 程式人生 > 其它 >【雜湊,長鏈剖分】CF504E Misha and LCP on Tree

【雜湊,長鏈剖分】CF504E Misha and LCP on Tree

原題面夠簡潔了:CF504E Misha and LCP on Tree

考慮經典套路二分+雜湊,雜湊先正著反著預處理一遍, check 時看當前答案 mid 落在哪條路徑的哪個點上,然後一通加減即可。

首先,求 \(LCA\) 與二分總共需要 \(O(m \log n)\),瓶頸在 check。由於 check 的核心是求樹上的 k 級祖先,用長鏈剖分(配合倍增)即可做到 \(O(1)\) 查詢,於是每次 check 只需 \(O(1)\)。

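總複雜度 \(O((n + m) \log n)\)。模數寫 33333331 被卡了,改成 998244353 卻過了。原因並不在於模數是否小眾:單模雜湊比較兩個不同字串時,誤判機率可粗略視為 \(1/\text{mod}\),而二分總共要做約 \(m \log n\) 次比較;當比較次數與 \(3.3 \times 10^7\) 同量級時,出現碰撞幾乎是必然的。998244353 大約大了 30 倍,誤判機率隨之下降;若要更穩妥,應改用雙模數或更大的模數。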

#include <cstdio>
#include <algorithm>
#include <vector>
#define pb push_back
using namespace std;
typedef long long LL;
// Reads the next decimal integer from stdin.
// Non-digit characters before the number are skipped; every '-' among them
// toggles the sign of the result.
// Returns 0 if the input ends before any digit is seen. The character is
// held in an int so that EOF is distinguishable from a real byte. The
// previous version stored it in a char and had no EOF check, so it looped
// forever at end of input (EOF < '0').
int read() {
    int ch = getchar();
    int x = 0, pd = 0;
    while (ch != EOF && (ch < '0' || ch > '9'))
        pd ^= ch == '-', ch = getchar();
    if (ch == EOF) return 0;  // no number left in the input
    while ('0' <= ch && ch <= '9')
        x = x * 10 + (ch ^ 48), ch = getchar();  // ch ^ 48 == ch - '0' for digits
    return pd ? -x : x;
}
// Modulus of the polynomial string hash.
// 33333331 was too small. Each query's binary search performs O(log n) hash
// comparisons, so a test makes on the order of m*log(n) comparisons, and a
// ~3.3e7 modulus makes a false "equal" verdict likely on large inputs.
// 998244353 is ~30x larger and prime, which the Fermat inverse
// Pow(base, mod - 2) relies on. It is also small enough that add() cannot
// overflow int (2 * mod < 2^31).
constexpr int mod = 998244353;
// Hash base. Letters are encoded as digits 0..25 (s[u] - 'a'), so base 26
// keeps equal-length strings distinct before reduction modulo `mod`.
constexpr int base = 26;

// (x + y) mod `mod`; requires x, y >= 0 and x + y < 2 * mod.
int add(int x, int y) { return x + y < mod ? x + y : x + y - mod; }
// (x * y) mod `mod`; the product is formed in 64 bits. Requires x, y >= 0.
int mul(int x, int y) { return 1ll * x * y % mod; }
// x^y mod `mod` by binary exponentiation; requires y >= 0.
int Pow(int x, int y) {
    int res = 1;
    for (; y; y >>= 1, x = mul(x, x))
        if (y & 1) res = mul(res, x);
    return res;
}
// Array capacity; nodes are indexed from 1.
const int maxn = 300005;
// Power tables: p[i] = base^i and invp[i] = base^(-i), both modulo `mod`,
// for 0 <= i <= maxn - 3.
int p[maxn], invp[maxn];

// Fills p[] and invp[]. Must run before any hash is built or queried.
void pre() {
    const int limit = maxn - 3;
    // Inverse of base via Fermat's little theorem (assumes `mod` is prime).
    const int invBase = Pow(base, mod - 2);
    p[0] = 1;
    invp[0] = 1;
    for (int k = 1; k <= limit; ++k) {
        p[k] = mul(p[k - 1], base);
        invp[k] = mul(invp[k - 1], invBase);
    }
}
// Number of nodes and number of queries.
int n, m;
// Node letters, 1-indexed: s[u] is the letter on node u.
char s[maxn];
// Undirected adjacency lists of the tree.
vector<int> to[maxn];
// lg[i] = floor(log2(i)); filled in main() for i >= 2.
int lg[maxn];
// fa[u][k] = 2^k-th ancestor of u (0 above the root); only k <= 18 is used.
int fa[maxn][20];
// d[u] = depth of u, with the root (node 1) at depth 0.
int d[maxn];
// son[u] = child of u whose subtree reaches deepest (long-chain heavy child).
int son[maxn];
// mxd[u] = greatest depth of any node in u's subtree.
int mxd[maxn];
// h1[u] = sum over nodes v on root..u of (s[v] - 'a') * base^d[v].
int h1[maxn];
// h2[u] = Horner hash of the letters on root..u (u is the last digit).
int h2[maxn];
// First DFS, rooted at node 1.
// On entry, fa[u][0], d[u], h1[u] (the parent's h1) and h2[u] (the parent's
// h2 times base) must already be seeded by the parent's loop below. For the
// root they are all still 0.
// Completes u's two hashes, fills fa[u][1..18], recurses into the children,
// and records mxd[u] and son[u]. son[u] is the first child, in adjacency
// order, with the largest mxd.
void dfs1(int u) {
    const int letter = s[u] - 'a';
    h1[u] = add(h1[u], mul(letter, p[d[u]]));
    h2[u] = add(h2[u], letter);
    mxd[u] = d[u];
    for (int k = 1; k <= 18; ++k) {
        const int half = fa[u][k - 1];
        fa[u][k] = fa[half][k - 1];
    }
    const int parent = fa[u][0];
    for (const int child : to[u]) {
        if (child == parent) continue;
        // Seed the child's depth, parent link and prefix hashes before descending.
        d[child] = d[u] + 1;
        h1[child] = h1[u];
        h2[child] = mul(h2[u], base);
        fa[child][0] = u;
        dfs1(child);
        if (mxd[child] > mxd[son[u]]) son[u] = child;
        if (mxd[child] > mxd[u]) mxd[u] = mxd[child];
    }
}
// top[u] = head of the long chain that contains u.
int top[maxn];
// Ladder tables, built only for chain heads t that have a heavy child:
// U[t][j] = j-th ancestor of t, for j up to the chain length;
// D[t][j] = j-th node below t along its long chain.
vector<int> U[maxn];
vector<int> D[maxn];
// Second DFS: long-chain decomposition.
// Sets top[u] = t, continues the chain through son[u], and starts a new
// chain at every other child. A leaf returns right after setting top[], so
// a chain made of a single leaf gets no ladder tables.
// At a chain head t, records U[t] (t and its ancestors, at most
// mxd[t] - d[t] + 1 entries, stopping at the root) and D[t] (the chain from
// t downward).
void dfs2(int u, int t) {
    top[u] = t;
    if (!son[u]) return;
    dfs2(son[u], t);
    for (const int child : to[u]) {
        if (child == fa[u][0] || child == son[u]) continue;
        dfs2(child, child);
    }
    if (u != t) return;
    const int span = mxd[u] - d[u];
    for (int node = u, step = 0; node != 0 && step <= span; node = fa[node][0], ++step)
        U[u].push_back(node);
    for (int node = u; node != 0; node = son[node])
        D[u].push_back(node);
}
// Lowest common ancestor of x and y by binary lifting.
// Requires lg[] to be filled and fa[][] to be built by dfs1.
int getlca(int x, int y) {
    if (d[x] > d[y]) swap(x, y);  // make y the deeper node
    // Lift y to x's depth, removing the highest set bit of the gap each step.
    int gap = d[y] - d[x];
    while (gap) {
        const int k = lg[gap];
        y = fa[y][k];
        gap ^= 1 << k;
    }
    if (x == y) return x;
    // Raise both nodes while their ancestors still differ.
    for (int k = 18; k >= 0; --k) {
        if (fa[x][k] == fa[y][k]) continue;
        x = fa[x][k];
        y = fa[y][k];
    }
    return fa[x][0];
}
// Returns the k-th ancestor of x; k == 0 returns x itself.
// Jumps 2^lg[k] levels with the lifting table, moves to that node's chain
// head, and then finishes with one lookup in the head's U (upward) or D
// (downward) ladder. Callers pass 0 <= k <= d[x].
int getKfa(int x, int k) {
    if (k == 0) return x;
    const int head = top[fa[x][lg[k]]];
    const int remaining = k - (d[x] - d[head]);
    if (remaining > 0) return U[head][remaining];
    return D[head][-remaining];
}
// Hash of the first Len nodes on the path x -> y, where lca = LCA(x, y).
// The letters are read in path order, with the first node as the most
// significant digit. Returns 0 when Len == 0. Callers keep Len at most the
// number of nodes on the path.
int getH(int x, int y, int lca, int Len) {
    if (Len == 0) return 0;
    const int upCount = d[x] - d[lca] + 1;  // nodes on x..lca, inclusive
    if (Len <= upCount) {
        // The whole prefix lies on the upward leg: x .. pos.
        const int pos = getKfa(x, Len - 1);
        const int segment = add(h1[x], mod - h1[fa[pos][0]]);
        return mul(segment, invp[d[pos]]);
    }
    // Upward leg x .. lca, normalised so that lca carries weight base^0.
    const int upHash = mul(add(h1[x], mod - h1[fa[lca][0]]), invp[d[lca]]);
    // Downward leg below lca, down to pos at depth d[lca] + downCount.
    const int downCount = Len - upCount;
    const int pos = getKfa(y, d[y] - d[lca] - downCount);
    const int downHash = add(h2[pos], mod - mul(h2[lca], p[downCount]));
    return add(mul(upHash, p[downCount]), downHash);
}
int main() {
    #ifdef LOCAL
    freopen("data.in", "r", stdin);
    freopen("pro.out", "w", stdout);
    #endif
    pre();
    n = read(), scanf("%s", s + 1);
    for (int i = 1; i < n; i++) {
        int x = read(), y = read();
        to[x].pb(y), to[y].pb(x);
    }
    dfs1(1), dfs2(1, 1);
    for (int i = 2; i <= n; i++) lg[i] = lg[i >> 1] + 1;
    m = read();
    while (m--) {
        int A = read(), B = read(), C = read(), D = read();
        int lca1 = getlca(A, B), lca2 = getlca(C, D);
        int l = 0, r = min(d[A] + d[B] - 2 * d[lca1], d[C] + d[D] - 2 * d[lca2]) + 2;
        while (l + 1 < r) {
            int mid = (l + r) >> 1;
            if (getH(A, B, lca1, mid) != getH(C, D, lca2, mid)) r = mid;
            else l = mid;
        } printf("%d\n", l);
    }
    return 0;
}