1. 程式人生 > 其它 >01 Trie 專項題解

01 Trie 專項題解

思維路徑:

  • xor 運算的特性
    • “是否不同”
    • 相同會消掉
    • 字首和思想(處理區間查詢、樹上路徑查詢)
  • 逐位處理
  • 從高位到低位貪心

HDU 4825 Xor Sum

板子題。

將所有數字二進位制從高位到低位插入 Trie 中,從高到低貪心地能取到 1 就取 1.

const int MAXN = 100000 + 10;

namespace Trie {
struct Node {
    int nxt[2]; Node() { nxt[0] = nxt[1] = 0; }
} node[MAXN * 34]; int root = 1, cnt = 1;

void Insert(lli x) {
    int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        int &nxt = node[u].nxt[(x >> bit) & 1];
        if (!nxt) nxt = ++cnt;
        u = nxt;
    }
}
lli Query(lli x) {
    lli ret = 0; int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        int xb = (x >> bit) & 1;
        int nxt = 0;
        ret <<= 1;
        if (node[u].nxt[xb ^ 1]) {
            nxt = node[u].nxt[xb ^ 1];
            ret |= xb ^ 1;
        } else {
            nxt = node[u].nxt[xb];
            ret |= xb;
        }
        u = nxt;
    }
    return ret;
}
void clear() {
    memset(node, 0, sizeof node); root = cnt = 1;
}
}

int n, m;

int main() {
    int T = read();
    for (int cas = 1; cas <= T; ++cas) {
        printf("Case #%d:\n", cas);
        n = read(); m = read();
        rep (i, 1, n) Trie::Insert(readll());
        rep (i, 1, m) printf("%lld\n", Trie::Query(readll()));
        Trie::clear();
    }
    return 0;
}

HDU 5536 Chip Factory

基本思路仍然是把一個東西插入 Trie,再用另一個查詢。

\(a_i + a_j\) 插入 Trie 會佔用大量空間,但總時間複雜度仍然不變,所以不如把 \(a_k\) 插入後列舉 \(a_i + a_j\)

關鍵是在於如何判斷 \(i \neq j \neq k\)

這裡可以藉助樹形 DP 的一個小技巧(類似補集轉換):求出所有的,減去不要的,就是要求的。

所以我們在列舉 \(a_i + a_j\) 時,對 \(a_i\)\(a_j\) 打一個標記,表示這個點不能走,求完了再恢復回來即可。

因為一條邊能經過多個數字,所以這個標記用 size

記錄比較舒服,貪心的時候先判 size != 0

const int MAXN = 1000 + 10;

namespace Trie {
struct Node {
    int nxt[2]; int passby; Node() { nxt[0] = nxt[1] = passby = 0; }
} node[MAXN * 34]; int root = 1, cnt = 1;

void Insert(lli x, int dt) {
    int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        node[u].passby += dt;
        int &nxt = node[u].nxt[(x >> bit) & 1];
        if (!nxt) nxt = ++cnt;
        u = nxt;
    }
    node[u].passby += dt;
}
lli Query(lli x) {
    lli ret = 0; int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        int xb = (x >> bit) & 1;
        int nxt = 0;
        ret <<= 1;
        if (node[u].nxt[xb ^ 1] && node[node[u].nxt[xb ^ 1]].passby) {
            nxt = node[u].nxt[xb ^ 1];
            ret |= 1;
        } else {
            nxt = node[u].nxt[xb];
        }
        u = nxt;
    }
    return ret;
}
void clear() {
    for (int i = 1; i <= cnt; ++i) node[i] = Node();
    root = cnt = 1;
}
}

int n, m;
int aa[MAXN];

int main() {
    int T = read();
    for (int cas = 1; cas <= T; ++cas) {
        n = read();
        rep (i, 1, n) aa[i] = read();
        rep (i, 1, n) Trie::Insert(aa[i], 1);
        lli ans = 0;
        for (int i = 1; i <= n; ++i) {
            for (int j = i + 1; j <= n; ++j) {
                Trie::Insert(aa[i], -1); Trie::Insert(aa[j], -1);
                ans = std::max(ans, Trie::Query(aa[i] + aa[j]));
                Trie::Insert(aa[i], 1); Trie::Insert(aa[j], 1);
            }
        }
        printf("%lld\n", ans);
        Trie::clear();
    }
    return 0;
}

BZOJ 4260 Codechef REBXOR

首先這個區間查詢可以用一個字首優化搞掉,\(a_i \oplus \dots \oplus a_j = s_j \oplus s_{i - 1}\)

所以答案就是讓我們分別求兩組 \(l, r\) 滿足 \(r_1 < l_2\),而且各自的 \(s_r \oplus s_{l - 1}\) 最大。

後面這個東西,固定一個端點可以求出前 / 字尾中選另一個端點的最大值,類似權值樹狀陣列的思想。

為了這個 \(r_1 < l_2\),我們可以考慮對於每一個點求出它前後綴的最大答案,兩個相鄰的加一下就是最終答案了。

const int MAXN = 4e5 + 10;

int n;
int aa[MAXN];
int pref[MAXN];
int suff[MAXN];

struct Trie {
    struct Node {
        int nxt[2];
    } node[MAXN * 32]; int root = 1, cnt = 1;

    void Insert(int x) {
        int u = root;
        for (int bit = 30; bit >= 0; --bit) {
            int xb = (x >> bit) & 1;
            int &nxt = node[u].nxt[xb];
            if (!nxt) nxt = ++cnt;
            u = nxt;
        }
    }
    int Query(int x) {
        int ret = 0; int u = root;
        for (int bit = 30; bit >= 0; --bit) {
            int xb = (x >> bit) & 1;
            int nxt = 0;
            ret <<= 1;
            if (node[u].nxt[xb ^ 1]) {
                nxt = node[u].nxt[xb ^ 1];
                ret |= 1;
            } else nxt = node[u].nxt[xb];
            u = nxt;
        }
        return ret;
    }
} prefs, suffs;

int prefans[MAXN], suffans[MAXN];

int main() {
    n = read();
    prefs.Insert(0); suffs.Insert(0);
    rep (i, 1, n) aa[i] = read();
    rep (i, 1, n) {
        pref[i] = pref[i - 1] ^ aa[i];
        int maxp = prefs.Query(pref[i]);
        prefans[i] = maxp;
        prefans[i] = std::max(prefans[i], prefans[i - 1]);
        prefs.Insert(pref[i]);
    }
    for (int i = n; i >= 1; --i) {
        suff[i] = suff[i + 1] ^ aa[i];
        int maxs = suffs.Query(suff[i]);
        suffans[i] = maxs;
        suffans[i] = std::max(suffans[i], suffans[i + 1]);
        suffs.Insert(suff[i]);
    }
    int ans = 0;
    for (int i = 1; i <= n - 1; ++i) {
        ans = std::max(ans, prefans[i] + suffans[i + 1]);
    }
    printf("%d\n", ans);
    return 0;
}

POJ 3764 The xor-longest Path

和上道題差不多的思想,使用樹上字首異或和 + 邊插入邊查詢。

兩個點之間路徑異或和就是它們到根節點異或和做異或。

對整棵樹做 DFS,依次把所有點到根節點的路徑異或和插入 Trie 裡,邊插入邊查詢異或最大值,這個做法顯然是可以遍歷到所有路徑的,具體可以通過列舉 LCA 來證明。

聽他們說用 vector 會 T,我也沒試過。

const int MAXN = 1e5 + 10;

int n;

struct Edge {
    int v, nxt; lli w;
} edge[MAXN << 1]; int head[MAXN], cnt;

void addEdge(int u, int v, lli w) {
    edge[++cnt] = {v, head[u], w}; head[u] = cnt;
}

namespace Trie {
struct Node { int nxt[2]; Node() {nxt[0] = nxt[1] = 0;} } node[MAXN * 34];
int root = 1, cnt = 1;

void clear() {
    for (int i = 1; i <= cnt; ++i) node[i] = Node();
    cnt = 1;
}
void Insert(lli x) {
    int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        int xb = (x >> bit) & 1;
        int &nxt = node[u].nxt[xb];
        if (!nxt) nxt = ++cnt;
        u = nxt;
    }
}
lli Query(lli x) {
    lli ret = 0;
    int u = root;
    for (int bit = 32; bit >= 0; --bit) {
        int xb = (x >> bit) & 1;
        int nxt = 0;
        ret <<= 1;
        if (node[u].nxt[xb ^ 1]) {
            nxt = node[u].nxt[xb ^ 1];
            ret |= 1;
        } else nxt = node[u].nxt[xb];
        u = nxt;
    } return ret;
}
}

lli ans = 0;
void dfs(int u, int fa, lli pref) {
    ans = std::max(ans, Trie::Query(pref));
    Trie::Insert(pref);
    for (int e = head[u]; e; e = edge[e].nxt) {
        int v = edge[e].v; lli w = edge[e].w;
        if (v == fa) continue;
        dfs(v, u, pref ^ w);
    }
}
void cleanup() {
    ans = 0;
    cnt = 0; for (int i = 1; i <= n; ++i) head[i] = 0;
    Trie::clear();
}
int _main() {
    Trie::Insert(0);
    for (int i = 1; i <= n - 1; ++i) {
        int u = read() + 1; int v = read() + 1; lli w = readll();
        addEdge(u, v, w); addEdge(v, u, w);
    }
    dfs(1, 0, 0);
    printf("%lld\n", ans);
    cleanup();
    return 0;
}
int main() {
    while (scanf("%d", &n) != EOF) _main();
}

Codeforces 842D Vitya and Strange Lesson

還挺有意思的這題。

首先全域性異或可以用一個標記記住,因為異或具有結合律。

然後就是這個鬼畜的全域性 Mex。

假設沒有全域性異或這個鬼東西,只讓我們求全域性 Mex,那麼我們可以用一個桶存下來所有數,求這個東西就相當於是在求桶的第一個空位是誰,有一個 \(\log\) 做法是使用權值線段樹上二分,每次判斷左區間是不是滿的(也就是 \(\mathrm{sum[lson]} = \mathrm{mid}\)),如果不是滿的就說明 Mex 在左邊,否則就在右邊。然而權值線段樹無法維護全域性 xor。

這個做法的關鍵是能快速求出左區間有多少數字(因此也不能出現重複數字,必須提前去重),於是我們需要一個支援查詢左區間 size 和逐位異或的資料結構,而這個東西可以用 01 Trie 來完成。

題外話,考慮到數字 Trie 從高位向低位存、不足位補前導零的這個特性,每一個數字的長度都是相同的,也就是說,所有數字都可以通過一個葉子結點唯一確定。因此,它類似於一個權值資料結構,也能完成一些權值資料結構的操作(求 rank,求前驅後繼),只不過單次操作複雜度是固定的 \(O(\mathrm{len})\),而權值線段樹、樹狀陣列等的單次操作複雜度是固定的 \(O(\log M)\),其中 \(M\) 是值域。
這個特性也決定了,在將 \(1\sim n\) 的所有數字插入時,樹的結構成為滿多叉樹,空間複雜度會達到指數級別。

先考慮沒有異或(或者異或的這一位是 0)的情況,和上面類似,對於第 \(b(b\geq 0)\) 位我們先求出 0 子樹的 size(可以在插入的時候記錄一下),然後再和滿的大小 \(2^b\) 判斷一下,如果不滿就往左走,否則就往右走。我們可以不建滿二叉樹,如果當前的 0 子樹不存在,就直接 return。

有異或的話,直接把 0 當 1 看,1 當 0 看即可。

const int MAXN = 3e5 + 10;

int n, m;

int addition;

namespace Trie {
struct Node { int nxt[2]; int siz; Node() { nxt[0] = nxt[1] = siz = 0; } } node[MAXN * 20];
int root = 1, cnt = 1;

void Insert(int x) {
    int u = root;
    for (int bit = 21; bit >= 0; --bit) {
        ++node[u].siz;
        int xb = (x >> bit) & 1;
        int &nxt = node[u].nxt[xb];
        if (!nxt) nxt = ++cnt;
        u = nxt;
    }
    ++node[u].siz;
}
int Query() {
    int u = root;
    int ret = 0;
    for (int bit = 21; bit >= 0; --bit) {
        ret <<= 1;
        int xb = (addition >> bit) & 1;
        int nxt = 0;
        if (!node[u].nxt[xb]) {
            ret <<= bit; return ret; // 一定別忘了 << bit
        }
        if (node[node[u].nxt[xb]].siz < (1 << bit)) nxt = node[u].nxt[xb];
        else {
            nxt = node[u].nxt[xb ^ 1]; ret |= 1;
        }
        u = nxt; // 一定別忘了 u = nxt
    }
    return ret;
}
}

int uniq[MAXN];

int main() {
    n = read(); m = read();
    rep (i, 1, n) {
        int x = read();
        if (!uniq[x]) Trie::Insert(x);
        uniq[x] = 1;
    }
    while (m --> 0) {
        addition ^= read();
        printf("%d\n", Trie::Query());
    }
    return 0;
}

Codeforces 713A Sonya and Queries

最後來一道水題收尾。

\(a, b\) 奇偶性相同等價於 \(a \equiv b\ (\bmod 2)\),所以就把所有的數字都逐位模 2 插進 Trie 即可。

const int MAXN = 1e5 + 10;

namespace Trie {
struct Node { int nxt[2], siz; Node() { nxt[0] = nxt[1] = siz = 0; } } node[MAXN * 20];
int root = 1, cnt = 1;

void Insert(lli x, int dt) {
    int u = root;
    for (int bit = 19; bit >= 0; --bit) {
        int &nxt = node[u].nxt[(x >> bit) & 1];
        if (!nxt) nxt = ++cnt;
        u = nxt;
    }
    node[u].siz += dt;
}
int Query(lli x) {
    int u = root;
    for (int bit = 19; bit >= 0; --bit) {
        u = node[u].nxt[(x >> bit) & 1];
    }
    return node[u].siz;
}
}

lli Process(lli x) {
    std::vector<int> res;
    while (x) {
        res.push_back((x % 10) % 2); x /= 10;
    }
    std::reverse(ALL(res));
    lli ret = 0;
    for (auto v : res) ret = (ret << 1) + v;
    return ret;
}

int t;

int main() {
    t = read();
    while (t --> 0) {
        char ss[2]; scanf("%s", ss);
        lli fx = readll();
        switch (ss[0]) {
            case '+': {
                Trie::Insert(Process(fx), 1);
                break;
            }
            case '-': {
                Trie::Insert(Process(fx), -1);
                break;
            }
            case '?': {
                printf("%d\n", Trie::Query(Process(fx)));
            }
        }
    }
    return 0;
}

易錯警示

  1. 老生常談的 typo。
  2. 求完 nxt 一定不要忘了更新 u = nxt
  3. node 陣列大小一定要乘上數字長度,或者乾脆就再寫一個 const int MAXNODE
  4. 多測清空的時候不要忘了 cnt = 1