1. 程式人生 > 實用技巧 >可持久化Trie - AcWing - 256 - 最大異或和

可持久化Trie - AcWing - 256 - 最大異或和

可持久化Trie

原題連結:AcWing-256-最大異或和

題解:

可以用字首陣列O(1)解決區間異或的詢問

\[A_i\oplus A_{i+1} \oplus A_{i+2} \oplus ... \oplus A_j = (A_1 \oplus A_2 \oplus ... \oplus A_j) \oplus (A_1 \oplus A_2 \oplus A_3 \oplus...\oplus A_{i-1}) \]

int arr[N]; // 原陣列
int pre[N]; // 字首異或陣列
for(int i = 1; i <= n; i++){
    scanf("%d",&arr[i]);
}
pre[1] = arr[1];
for(int i = 2; i <= n; i++){
	pre[i] = pre[i-1] ^ arr[i];
}

對於題目的每一個Q詢問,我們找到一個這樣的p使得\(p \in [l,r]\),且\(A_p \oplus A_{p+1} \oplus ... \oplus A_N \oplus x\)最大

我們可以通過剛才得到的字首陣列來將問題簡化成找到一個這樣的p使得\(p \in [l,r]\),且pre[p-1] ^ pre[n] ^ x最大

而對於每一個獨立的詢問,pre[n] ^ x是固定不變的,不妨設y = pre[n] ^ x,因此我們最終的問題就是如何找到最優的p使得pre[p-1] ^ y 最大

我們從二進位制的角度考慮這個問題. ^是按位異或,我們要想找到最大的pre[p-1]^y,肯定是基於貪心的思路,優先找高位異或為1的pre[p-1].不妨建立一顆01Trie

來維護. 對於每一個給定的y,我們從01Trie的樹根和y的最高位開始,優先選擇高位異或為1的結點.最終遍歷到Trie的樹根就可以得到最優答案.(這個貪心如果不理解,可以先做AcWing - 143 - 最大異或對)

接下來的問題是,如何使得p從[l,r]區間中進行選擇?不妨由淺入深分析這個問題:

  • 假設只詢問\(p\in [1,n]\),那麼只需要最基本的Trie就可以維護了.
  • 假設詢問\(p\in[1,r] , r \leqslant n\),那麼說明我們需要維護歷史版本的Trie,這一點可以用可持久化Trie實現.
  • 現在要詢問\(p \in [l,r],\)也就是說我們需要在歷史版本為r的基礎上,篩除版本在[1,l-1]
    的數.假設我們取出版本為r的Trie如下圖所示.

假設我們要求\(p \in [5,12]\),那麼對於每一個結點,至少出現一次5以上的版本我才能夠訪問這個結點. 這個問題等價於: 某個結點的最大歷史版本號大於或等於5才能訪問這個結點.因此,我們不妨對於每一個結點記錄其最大歷史版本號,就能夠解決左端點的限制問題.

整理一下思路,我們的解題步驟為:

  • 對於單次加點操作: 更新字首異或陣列 (複雜度為O(1)) 然後更新可持久化Trie,新增新版本(常數複雜度,32左右)
  • 對於單次詢問操作: 從r版本Trie樹根開始,依次向下層深入,比較每個結點的最大版本號與l的大小,貪心.(常數複雜度,32左右)

因此整個演算法的複雜度為O(n)級別,大概執行3e5 * 32次

#include <cstdio>
#include <cstdlib>
#include <vector>
using namespace std;
#define LENGTH 23
#define MAX(a,b) (a>b?a:b)
const int N = 6e5+10;
int len;           // 當前版本號
int idx;           // 索引分配器
int pre[N];   // 記錄各個版本的字首異或陣列
int roots[N]; // 記錄各個版本的樹根
int maxTime[LENGTH*N];
int nex[LENGTH*N][2];

void insert(int x){     // 向字典樹中加入一個值 
    pre[1+len] = pre[len] ^ x; 
    ++len;

    ++idx;
    roots[len] = idx;
    maxTime[idx] = len;

    int rt = idx;
    int pos = 1 << LENGTH;

    if(len == 1){ // 如果當前版本號為1
        for(int i = LENGTH; i >= 0; i--){
            ++idx;
            if(pre[len] & pos){
                nex[rt][1] = idx;
                rt = idx;
            }else{
                nex[rt][0] = idx;
                rt = idx;
            }
            pos >>= 1;
            maxTime[rt] = 1;
        }
    }else{
        int flag = 1;
        int oldRoot = roots[len-1];
        for(int i = LENGTH; i >= 0; i--){
            ++idx;
            if(pre[len] & pos){
                nex[rt][1] = idx;
                if(flag){
                    nex[rt][0] = nex[oldRoot][0];
                    oldRoot = nex[oldRoot][1];
                    if(!oldRoot){
                        flag = 0;
                    }
                }
                rt = idx;
            }else{
                nex[rt][0] = idx;
                if(flag){
                    nex[rt][1] = nex[oldRoot][1];
                    oldRoot = nex[oldRoot][0];
                    if(oldRoot == -1){
                        flag = 0;
                    }
                }
                rt = idx;
            }
            pos >>= 1;
            maxTime[rt] = len;
        }
    }
}
int query(int l,int r,int x){
    x ^= pre[len];
    int pos = 1 << LENGTH;
    int v = 0;
    int rt = roots[r];
    for(int i = LENGTH; i >= 0; i--){
        if(x & pos){
            if(nex[rt][0] && maxTime[nex[rt][0]]>= l){
                v <<= 1;
                rt = nex[rt][0];
            }else{
                v <<= 1;
                v |= 1;
                rt = nex[rt][1];
            }
        }else{
            if(nex[rt][1] && maxTime[nex[rt][1]] >= l){
                v <<= 1;
                v |= 1;
                rt = nex[rt][1];
            }else{
                v <<= 1;
                rt = nex[rt][0];
            }
        }
        pos >>= 1;
    }

    return x ^ v;
}
int main(){
    int n,m,l,r,temp;
    char cmd[5];
    scanf("%d%d",&n,&m);
    for(int i = 1; i <= n; i++){
        scanf("%d",&temp);
        insert(temp);
    }

    for(int i = 1; i <= m; i++){
        scanf("%s",cmd);
        if(cmd[0]=='A'){
            scanf("%d",&temp);
            insert(temp);
        }else{
            scanf("%d%d%d",&l,&r,&temp);
            l--;
            r--;
            if(r == 0){
                printf("%d\n",pre[len]^temp);
            }else{
                printf("%d\n",query(l,r,temp));
            }
        }
    }
    system("pause");
    return 0;
}