1. 程式人生 > >0-1字典樹總結和經典例題

0-1字典樹總結和經典例題

0-1字典樹:

0-1字典樹主要用於解決求異或最值的問題,0-1字典樹其實就是一個二叉樹,和普通的字典樹原理類似,只不過把插入字元改成了插入二進位制串的每一位(0或1)。

下面先給出0-1字典樹的簡單模板:

LL val[32 * MaxN]; //點的值 
int ch[32 * MaxN][2]; //邊的值 
int tot; //節點個數 
 
void add(LL x) { //往 01字典樹中插入 x 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(!ch[u][v]) { //如果節點未被訪問過 
            ch[tot][0] = ch[tot][1] = 0; //將當前節點的邊值初始化 
            val[tot] = 0; //節點值為0,表示到此不是一個數 
            ch[u][v] = tot++; //邊指向的節點編號 
        }
        u = ch[u][v]; //下一節點 
    }
    val[u] = x; //節點值為 x,即到此是一個數 
}
 
LL query(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        //利用貪心策略,優先尋找和當前位不同的數 
        if(ch[u][v^1]) u = ch[u][v^1];
        else u = ch[u][v];
    }
    return val[u]; //返回結果 
}

不難發現以下事實:

  • 01字典樹是一棵最多32層的二叉樹,其每個節點的兩條邊分別表示二進位制的某一位的值為 0 還是為 1。將某個路徑上邊的值連起來就得到一個二進位制串。
  • 節點個數為 1 的層(最高層,也就是根節點)節點的邊對應著二進位制串的最高位,向下的每一層逐位降低。
  • 以上程式碼中,ch[i] 表示一個節點,ch[i][0] 和 ch[i][1] 表示節點的兩條邊指向的節點,val[i] 表示節點的值。
  • 每個節點主要有4個屬性:節點值、節點編號、兩條邊指向的下一節點的編號。
  • 節點值 val 為 0時表示到當前節點為止不能形成一個數,否則 val[i] = 數值。
  • 節點編號在程式執行時生成,無規律。
  • 可通過貪心的策略來尋找與 x 異或結果最大的數,即優先找和 x 二進位制的未處理的最高位值不同的邊對應的點,這樣保證結果最大。

複雜度:O(32*n)

 

例題1. CSU 1216:異或最大值

http://acm.csu.edu.cn/csuoj/problemset/problem?pid=1216

Description

給定一些數,求這些數中兩個數的異或值最大的那個值。

(對於一個長度為 n 的陣列a1, a2, …, an,請找出不同的 i, j,使 ai ^ aj 的值最大)

Input

多組資料

。第一行為數字個數n,1 <= n <= 10 ^ 5。接下來n行每行一個32位有符號非負整數。

Output

任意兩數最大異或值

 

Solution:

貪心找最大異或值: 

把每一個數以二進位制形式從高位到低位插入trie樹中,依次列舉每個數,在trie中貪心,即當前為0則向1走,為1則向0走。

異或運算有一個性質,就是對應位不一樣則為1,要使結果最大化,就要讓越高的位為1,所以找與一個數使得兩數的異或結果最大,就需要從樹的根結點(也就是最高位)開始找,如果對應位置的這個數是0,優先去找那一位為1的數,否則再找0;同理,如果對應位置的這個數是1,優先去找那一位為0的數,否則再找1。最終找到的數就是跟這個數異或結果最大的數。

對於n個數,每個數找一個這樣的數並算出結果求其中的最大值即可。

 

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2];  
int tot; 
 
void add(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(!ch[u][v]) {  
            ch[tot][0] = ch[tot][1] = 0;  
            val[tot] = 0;  
            ch[u][v] = tot++; 
        }
        u = ch[u][v]; 
    }
    val[u] = x; 
}
 
LL query(LL x) { 
    int u = 0;
    for(int i = 32; i >= 0; i--) {
        int v = (x >> i) & 1;
        if(ch[u][v^1]) u = ch[u][v^1];
        else u = ch[u][v];
    }
    return val[u];  
}

int main(){
	int n; 
	while(cin >> n) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		LL ans = 0LL;
		for(int i = 1; i <= n; i++) {
			ans = max(ans, a[i] ^ query(a[i]));
		}
		cout << ans << endl;
	}
	return 0;
}

 

例題2. HDU 4825 Xor Sum

http://acm.hdu.edu.cn/showproblem.php?pid=4825

Problem Description

Zeus 和 Prometheus 做了一個遊戲,Prometheus 給 Zeus 一個集合,集合中包含了N個正整數,隨後 Prometheus 將向 Zeus 發起M次詢問,每次詢問中包含一個正整數 S ,之後 Zeus 需要在集合當中找出一個正整數 K ,使得 K 與 S 的異或結果最大。Prometheus 為了讓 Zeus 看到人類的偉大,隨即同意 Zeus 可以向人類求助。你能證明人類的智慧麼?

Input

輸入包含若干組測試資料,每組測試資料包含若干行。
輸入的第一行是一個整數T(T < 10),表示共有T組資料。
每組資料的第一行輸入兩個正整數N,M(<1=N,M<=100000),接下來一行,包含N個正整數,代表 Zeus 的獲得的集合,之後M行,每行一個正整數S,代表 Prometheus 詢問的正整數。所有正整數均不超過2^32。

Output

對於每組資料,首先需要輸出單獨一行”Case #?:”,其中問號處應填入當前的資料組數,組數從1開始計算。
對於每個詢問,輸出一個正整數K,使得K與S異或值最大。

 

Description:

m組詢問,每次詢問給出一個數,求在n個數中找出一個數,使得與當前數的異或結果最大。

 

Solution:

與上一題基本一樣

 

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2];  
int tot; 

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
	}
	val[u] = x; 
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(ch[u][v^1]) u = ch[u][v^1];
		else u = ch[u][v];
	}
	return val[u];  
}

int main(){
	int t; cin >> t;
	for(int cas = 1; cas <= t; cas++) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		int n, m; cin >> n >> m;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		cout << "Case #" << cas << ":" << endl;
		while(m--) {
			LL x; cin >> x;
			cout << query(x) << endl;
		}
	}
	return 0;
}

 

例題3. HDU 5536 Chip Factory

http://acm.hdu.edu.cn/showproblem.php?pid=5536

Problem Description

John is a manager of a CPU chip factory, the factory produces lots of chips everyday. To manage large amounts of products, every processor has a serial number. More specifically, the factory produces n chips today, the i-th chip produced this day has a serial number si.
At the end of the day, he packages all the chips produced this day, and send it to wholesalers. More specially, he writes a checksum number on the package, this checksum is defined as below:

maxi,j,k(si+sj)⊕sk
which i,j,k are three different integers between 1 and n. And ⊕ is symbol of bitwise XOR.
Can you help John calculate the checksum number of today?

Input

The first line of input contains an integer T indicating the total number of test cases.
The first line of each test case is an integer n, indicating the number of chips produced today. The next line has n integers s1,s2,..,sn, separated with single space, indicating serial number of each chip.
1≤T≤1000
3≤n≤1000
0≤si≤109
There are at most 10 testcases with n>100

Output

For each test case, please output an integer indicating the checksum number in a line.

 

Description:

在一個數組中找出 (s[i] + s[j]) ^ s[k] 的最大值,其中 i、j、k 各不相同。

 

Solution:

由於題目中的資料範圍很小,可以暴力列舉 i 和 j,與上面的例題不同的是,由於規定 i, j, k 各不相同,所以需要增加一個 update 操作,用來記錄增加或減少一個數後每個節點的訪問次數,通過訪問次數是否大於0判斷當前數是否被使用過(也就是a[i], a[j])。

 

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
const int MaxN = 1e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2], vis[32 * MaxN];  
int tot; 

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			vis[tot] = 0; //
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
		vis[u]++; //
	}
	val[u] = x; 
}

void update(LL x, int add) { //更新插入或刪除x後每個節點被訪問的次數
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		u = ch[u][v];
		vis[u] += add;
	}
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		//if(ch[u][v^1]) u = ch[u][v^1];
		if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1]; //訪問次數大於0說明當前數不是a[i],a[j]
		else u = ch[u][v];
	}
	return val[u];  
}

int main(){
	int t; cin >> t;
	for(int cas = 1; cas <= t; cas++) {
		ch[0][0] = ch[0][1] = 0; 
		tot = 1;
		int n; cin >> n;
		for(int i = 1; i <= n; i++) {
			cin >> a[i];
			add(a[i]);
		}
		LL ans = 0LL;
		for(int i = 1; i <= n; i++) {
			for(int j = 1; j <= n; j++) {
				if(i == j) continue;
				update(a[i], -1);
				update(a[j], -1);
				ans = max(ans, (a[i]+a[j]) ^ query(a[i]+a[j]));
				update(a[i], 1);
				update(a[j], 1);
			}
		}
		cout << ans << endl;
	}
	return 0;
}

 

例題4. BZOJ 4260: Codechef REBXOR

https://www.lydsy.com/JudgeOnline/problem.php?id=4260

Description:

給出 n 個數,求兩個不相交的區間中的元素異或後的和的最大值

 

Solution:

首先考慮異或的一個性質:0 ^ a = a,a ^ a = 0。前 i 個數的異或結果和前 j 個數的異結果再進行異或: pre[i] ^ pre[j] = a[i+1] ^ a[i+2] ^ …^ a[j] (i < j)。異或的字尾和同理。

於是可以通過先求出異或的字首 pre[i] 和字尾 suf[i]。dp[i] 表示前 i 個數中任意區間異或後的最大值,可以依次求與 pre[i] 相異或結果的最大值,然後把 pre[i] 插入到 01字典樹中。

這樣對於每個 pre[i] 就會和之前的 i-1 個異或字首和的共有部分相抵消,也就相當於是求任意區間的異或結果的最大值了。這樣求出了一個區間,同理可利用字尾和求出另一個區間。

那麼如何保證兩個區間不相交呢?可以通過使前後兩個區間一個為不包含第 i 個數的前部分割槽間,一個是包含第 i 個數的後部分割槽間就可以了。

 

Code:

#include <cstdio>
#include <cstring>
#include <string>
#include <cmath>
#include <cstdlib>
#include <ctime>
#include <iostream>
#include <algorithm>
#include <vector>
#include <queue>
#include <stack>
#include <set>
#include <map>
#define fi first
#define se second
#define mst(a, b) memset(a, b, sizeof(a))
using namespace std;
typedef long long LL;
typedef pair<int, int> PII;
const int INF = 0x3f3f3f3f;
const double eps = 1e-9;
const int Mod = 1e9 + 7;
const int MaxN = 4e5 + 5;

LL a[MaxN];
LL val[32 * MaxN]; 
int ch[32 * MaxN][2]; //vis[32 * MaxN]; //記錄訪問次數
int tot; 
LL dp[MaxN], pre[MaxN], suf[MaxN];

void add(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(!ch[u][v]) {  
			ch[tot][0] = ch[tot][1] = 0;  
			val[tot] = 0;  
			//vis[tot] = 0;
			ch[u][v] = tot++; 
		}
		u = ch[u][v]; 
		//vis[u]++;
	}
	val[u] = x; 
}

void update(LL x, int add) { //更新插入或刪除x後每個節點被訪問的次數
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		u = ch[u][v];
		//vis[u] += add;
	}
}

LL query(LL x) { 
	int u = 0;
	for(int i = 32; i >= 0; i--) {
		int v = (x >> i) & 1;
		if(ch[u][v^1]) u = ch[u][v^1];
		//if(ch[u][v^1] && vis[ch[u][v^1]]) u = ch[u][v^1];
		else u = ch[u][v];
	}
	return x ^ val[u];  
}

int main(){
	ch[0][0] = ch[0][1] = 0; tot = 1;
	int n; cin >> n;
	for(int i = 1; i <= n; i++) cin >> a[i];
	pre[0] = suf[n+1] = 0;
	for(int i = 1; i <= n; i++) pre[i] = pre[i-1] ^ a[i];
	for(int i = n; i >= 1; i--) suf[i] = suf[i+1] ^ a[i];
	mst(dp, 0);
	add(pre[0]);
	for(int i = 1; i <= n; i++) {
		dp[i] = max(dp[i-1], query(pre[i])); //即前i個數的任意區間異或的最大值
		add(pre[i]);
	}
	ch[0][0] = ch[0][1] = 0; tot = 1;
	add(suf[n+1]);
	LL ans = 0;
	for(int i = n; i >= 1; i--) {
		ans = max(ans, query(suf[i]) + dp[i-1]);
		add(suf[i]);
	}
	cout << ans << endl;
return 0;
}

 

例題5. POJ 3764 The xor-longest Path

http://poj.org/problem?id=3764
Description

In an edge-weighted tree, the xor-length of a path p is defined as the xor sum of the weights of edges on p:

⊕ is the xor operator.

We say a path the xor-longest path if it has the largest xor-length. Given an edge-weighted tree with n nodes, can you find the xor-longest path?  

Input

The input contains several test cases. The first line of each test case contains an integer n(1<=n<=100000), The following n-1 lines each contains three integers u(0 <= u < n),v(0 <= v < n),w(0 <= w < 2^31), which means there is an edge between node u and v of length w.

Output

For each test case output the xor-length of the xor-longest path.

 

Description:

在樹上找一段路徑(連續)使得邊權相異或的結果最大。

 

Solution:

 

 

Code: